Skip to content

Commit 705ef0a

Browse files
authored
Merge branch 'main' into dk/cleanup-matprod_dest
2 parents a021822 + ef56ea2 commit 705ef0a

7 files changed

Lines changed: 55 additions & 18 deletions

File tree

src/SparseArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import Base: Matrix, Vector
1919
import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,
2020
issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded, isdiag,
2121
cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu,
22-
matprod_dest, generic_matvecmul!, generic_matmatmul!, copytrito!
22+
matprod_dest, generic_matvecmul!, generic_matmatmul!, generic_matmatmul_wrapper!, copytrito!
2323

2424
import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex,
2525
conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin,

src/linalg.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
5353
_At_or_Ac_mul_B!(transpose, C, A, wrap(B, tB), alpha, beta)
5454
elseif tA_uc == 'C'
5555
_At_or_Ac_mul_B!(adjoint, C, A, wrap(B, tB), alpha, beta)
56-
elseif tA_uc in ('S', 'H') && tB_uc == 'N'
56+
elseif tA_uc in ('S', 'H')
5757
rangefun = isuppercase(tA) ? nzrangeup : nzrangelo
5858
diagop = tA_uc == 'S' ? identity : real
5959
odiagop = tA_uc == 'S' ? transpose : adjoint
6060
T = eltype(C)
61-
_mul!(rangefun, diagop, odiagop, C, A, B, T(alpha), T(beta))
61+
_mul!(rangefun, diagop, odiagop, C, A, wrap(B, tB), T(alpha), T(beta))
6262
else
6363
@stable_muladdmul LinearAlgebra._generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(alpha, beta))
6464
end
@@ -108,7 +108,7 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
108108
C
109109
end
110110

111-
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number)
111+
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number, ::LinearAlgebra.BlasFlag.SyrkHerkGemm)
112112
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
113113
if tB == 'N'
114114
_spmul!(C, transA(A), B, alpha, beta)
@@ -119,6 +119,9 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB
119119
end
120120
return C
121121
end
122+
Base.@constprop :aggressive generic_matmatmul_wrapper!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number, @nospecialize(val)) =
123+
LinearAlgebra._generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
124+
122125
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
123126
mX, nX = size(X)
124127
nX == size(A, 1) ||
@@ -1733,7 +1736,7 @@ function opnormestinv(A::AbstractSparseMatrixCSC{T}, t::Integer = min(2,maximum(
17331736
repeated = true
17341737
end
17351738
end
1736-
if !repeated
1739+
if !repeated && 2^(n-1) 2t #we need enough non-parallel ±1 vectors
17371740
saux2 = S[1:n,j]' * S_old[1:n,1:t]
17381741
if _any_abs_eq(saux2,n)
17391742
repeated = true

src/solvers/cholmod.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,11 +1287,9 @@ function size(F::Factor, i::Integer)
12871287
return 1
12881288
end
12891289
size(F::Factor) = (size(F, 1), size(F, 2))
1290-
axes(A::Union{Dense,Sparse,Factor}) = map(Base.OneTo, size(A))
12911290

12921291
IndexStyle(::Type{<:Dense}) = IndexLinear()
12931292

1294-
size(FC::FactorComponent, i::Integer) = size(FC.F, i)
12951293
size(FC::FactorComponent) = size(FC.F)
12961294

12971295
adjoint(FC::FactorComponent{Tv,:L}) where {Tv} = FactorComponent{Tv,:U}(FC.F)

src/solvers/spqr.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ struct QRSparseQ{Tv,Ti<:Integer} <: AbstractQ{Tv}
115115
end
116116

117117
Base.size(Q::QRSparseQ) = (size(Q.factors, 1), size(Q.factors, 1))
118-
Base.axes(Q::QRSparseQ) = map(Base.OneTo, size(Q))
119118

120119
Matrix{T}(Q::QRSparseQ) where {T} = lmul!(Q, Matrix{T}(I, size(Q, 1), min(size(Q, 1), Q.n)))
121120

@@ -154,7 +153,6 @@ function Base.size(F::QRSparse, i::Integer)
154153
throw(ArgumentError("second argument must be positive"))
155154
end
156155
end
157-
Base.axes(F::QRSparse) = map(Base.OneTo, size(F))
158156

159157
# From SPQR manual p. 6
160158
_default_tol(A::AbstractSparseMatrixCSC) =
@@ -518,8 +516,23 @@ function LinearAlgebra.ldiv!(X::StridedVecOrMat{T}, F::QRSparse{T}, B::StridedVe
518516
lmul!(adjoint(F.Q), W0)
519517

520518
# Solve R*X = Q'*P*B
521-
ldiv!(UpperTriangular(@view(F.R[Base.OneTo(rnk), Base.OneTo(rnk)])),
522-
@view(W0[Base.OneTo(rnk), :]))
519+
#
520+
# We call generic_trimatdiv! directly instead of going through
521+
# ldiv!(UpperTriangular(R_sub), ...) for two reasons:
522+
# 1. UpperTriangular requires a square matrix, but F.R is m×n
523+
# so we can only take a column view R[:, 1:rnk] (which is
524+
# m×rnk, not square). A row+column view R[1:rnk, 1:rnk]
525+
# would be square but doesn't match SparseMatrixCSCView,
526+
# causing dispatch to a slow generic fallback.
527+
# 2. generic_trimatdiv! is what UpperTriangular ldiv! dispatches
528+
# to anyway — calling it directly with uploc='U', isunitc='N',
529+
# tfun=identity is equivalent. The back-substitution loop
530+
# iterates over axes(B,1) = 1:rnk and searchsortedlast
531+
# excludes entries with row > j, so the extra rows in the
532+
# column view are never accessed.
533+
W_rnk = @view(W0[Base.OneTo(rnk), :])
534+
LinearAlgebra.generic_trimatdiv!(W_rnk, 'U', 'N', identity,
535+
@view(F.R[:, Base.OneTo(rnk)]), W_rnk)
523536

524537
# Apply right permutation: scatter solved rows into X using cpiv directly.
525538
# Zero X first so free variables (beyond rank) are zero in the basic solution.

src/sparsematrix.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,8 @@ Base.@propagate_inbounds nzrange(S::SparseMatrixCSCColumnSubset, col::Integer) =
313313
nzrange(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}, i::Integer) = nzrangeup(S.data, i)
314314
nzrange(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}, i::Integer) = nzrangelo(S.data, i)
315315

316+
indtype(S::SparseMatrixCSCColumnSubset{<:Any,Ti}) where {Ti} = Ti
317+
316318
const AbstractSparseMatrixCSCInclAdjointAndTranspose = Union{AbstractSparseMatrixCSC,Adjoint{<:Any,<:AbstractSparseMatrixCSC},Transpose{<:Any,<:AbstractSparseMatrixCSC}}
317319
function Base.isstored(A::AbstractSparseMatrixCSC, i::Integer, j::Integer)
318320
@boundscheck checkbounds(A, i, j)

test/linalg.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,20 @@ end
236236
end
237237
end
238238

239+
@testset "Dense times symmetric/Hermitian sparse matrix multiplication" begin
240+
A = [1 3; 2 4]
241+
As = sparse(A)
242+
B = [1 1; 1 1]
243+
@test mul!(copy(B), B, Hermitian(A), true, true) == mul!(copy(B), B, Hermitian(As), true, true)
244+
end
245+
246+
@testset "Column view of sparse matrix " begin
247+
S = sparse(1:4, 1:4, 1:4)
248+
Sv = @view S[:,3:4]
249+
@test Sv * sparse(ones(2)) == Sv*ones(2) == Matrix(Sv) * ones(2)
250+
@test Sv * sparse(ones(2,2)) == Sv*ones(2,2) == Matrix(Sv) * ones(2,2)
251+
end
252+
239253
@testset "in-place sparse-sparse mul!" begin
240254
for n in (20, 30)
241255
sA = sprandn(ComplexF64, n, n, 0.1); A = Array(sA)

test/linalg_solvers.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,20 +130,24 @@ end
130130
@testset "sparse matrix cond" begin
131131
Random.seed!(1235)
132132
local A = sparse(reshape([1.0], 1, 1))
133-
Ac = sprandn(20, 20,.5) + im*sprandn(20, 20,.5)
134-
Ar = sprandn(20, 20,.5) + eps()*I
135133
@test cond(A, 1) == 1.0
136-
# For a discussion of the tolerance, see #14778
137-
@test 0.99 <= cond(Ar, 1) \ opnorm(Ar, 1) * opnorm(inv(Array(Ar)), 1) < 3
138-
@test 0.99 <= cond(Ac, 1) \ opnorm(Ac, 1) * opnorm(inv(Array(Ac)), 1) < 3
139-
@test 0.99 <= cond(Ar, Inf) \ opnorm(Ar, Inf) * opnorm(inv(Array(Ar)), Inf) < 3
140-
@test 0.99 <= cond(Ac, Inf) \ opnorm(Ac, Inf) * opnorm(inv(Array(Ac)), Inf) < 3
141134
@test_throws ArgumentError cond(A,2)
142135
@test_throws ArgumentError cond(A,3)
143136
Arect = spzeros(10, 6)
144137
@test_throws DimensionMismatch cond(Arect, 1)
145138
@test_throws ArgumentError cond(Arect,2)
146139
@test_throws DimensionMismatch cond(Arect, Inf)
140+
Ac = sprandn(20, 20,.5) + im*sprandn(20, 20,.5)
141+
Ar = sprandn(20, 20,.5) + eps()*I
142+
# For a discussion of the tolerance, see #14778
143+
@test 0.99 <= cond(Ar, 1) \ opnorm(Ar, 1) * opnorm(inv(Array(Ar)), 1) < 3
144+
@test 0.99 <= cond(Ac, 1) \ opnorm(Ac, 1) * opnorm(inv(Array(Ac)), 1) < 3
145+
@test 0.99 <= cond(Ar, Inf) \ opnorm(Ar, Inf) * opnorm(inv(Array(Ar)), Inf) < 3
146+
@test 0.99 <= cond(Ac, Inf) \ opnorm(Ac, Inf) * opnorm(inv(Array(Ac)), Inf) < 3
147+
#issue 680
148+
A22 = sparse(randn(2,2))
149+
@test 0.99 cond(Array(A22), 1) / cond(A22, 1) < 3
150+
@test 0.99 cond(Array(A22), Inf) / cond(A22, Inf) < 3
147151
end
148152

149153
@testset "sparse matrix opnormestinv" begin
@@ -159,6 +163,9 @@ end
159163
@test_throws ArgumentError SparseArrays.opnormestinv(Ac,0)
160164
@test_throws ArgumentError SparseArrays.opnormestinv(Ac,21)
161165
@test_throws DimensionMismatch SparseArrays.opnormestinv(sprand(3,5,.9))
166+
#issue 680
167+
A33 = sparse(randn(3,3))
168+
@test SparseArrays.opnormestinv(A33,3) opnorm(inv(Array(A33)),1) atol=1e-4
162169
end
163170

164171
@testset "factorization" begin

0 commit comments

Comments
 (0)