@@ -4,6 +4,41 @@ using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLower
44 RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap
55using Random: rand!
66
7+ _fix_size (M, nrow, ncol) = M
8+
9+ # An immutable fixed size wrapper for matrices to work around
10+ # the performance issue caused by https://github.com/JuliaLang/julia/issues/60409
11+ # This is more-of-less a stripped down version of FixedSizeArrays
12+ # which we can't easily use without pulling that into the standard library.
13+ struct _FixedSizeMatrix{Trans,R}
14+ ref:: R
15+ nrow:: Int
16+ ncol:: Int
17+ function _FixedSizeMatrix {Trans} (ref:: R , nrow, ncol) where {Trans,R}
18+ new {Trans,R} (ref, nrow, ncol)
19+ end
20+ end
21+ @inline Base. getindex (A:: _FixedSizeMatrix{'N'} , i, j) =
22+ @inbounds Core. memoryrefnew (A. ref, A. nrow * (j - 1 ) + i, false )[]
23+ @inline Base. setindex! (A:: _FixedSizeMatrix{'N'} , v, i, j) =
24+ @inbounds Core. memoryrefnew (A. ref, A. nrow * (j - 1 ) + i, false )[] = v
25+
26+ @inline Base. getindex (A:: _FixedSizeMatrix{'T'} , i, j) =
27+ @inbounds transpose (Core. memoryrefnew (A. ref, A. ncol * (i - 1 ) + j, false )[])
28+ @inline Base. setindex! (A:: _FixedSizeMatrix{'T'} , v, i, j) =
29+ @inbounds Core. memoryrefnew (A. ref, A. ncol * (i - 1 ) + j, false )[] = transpose (v)
30+
31+ @inline Base. getindex (A:: _FixedSizeMatrix{'C'} , i, j) =
32+ @inbounds adjoint (Core. memoryrefnew (A. ref, A. ncol * (i - 1 ) + j, false )[])
33+ @inline Base. setindex! (A:: _FixedSizeMatrix{'C'} , v, i, j) =
34+ @inbounds Core. memoryrefnew (A. ref, A. ncol * (i - 1 ) + j, false )[] = adjoint (v)
35+
36+ @inline _fix_size (A:: Matrix , nrow, ncol) = _FixedSizeMatrix {'N'} (A. ref, nrow, ncol)
37+ @inline _fix_size (A:: Transpose{<:Any,<:Matrix} , nrow, ncol) =
38+ _FixedSizeMatrix {'T'} (A. parent. ref, nrow, ncol)
39+ @inline _fix_size (A:: Adjoint{<:Any,<:Matrix} , nrow, ncol) =
40+ _FixedSizeMatrix {'C'} (A. parent. ref, nrow, ncol)
41+
742const tilebufsize = 10800 # Approximately 32k/3
843
944# In matrix-vector multiplication, the correct orientation of the vector is assumed.
@@ -64,52 +99,99 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
6499 T = eltype (C)
65100 _mul! (rangefun, diagop, odiagop, C, A, wrap (B, tB), T (alpha), T (beta))
66101 else
67- @stable_muladdmul LinearAlgebra. _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul ( alpha, beta) )
102+ LinearAlgebra. _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
68103 end
69104 return C
70105end
71106
107+ # Slow non-inlined functions for throwing the error without messing up the caller
108+ @noinline function _matmul_size_error (mC, nC, mA, nA, mB, nB, At, Bt)
109+ if At == ' N'
110+ Anames = " first" , " second"
111+ else
112+ Anames = " second" , " first"
113+ end
114+ if Bt == ' N'
115+ Bnames = " first" , " second"
116+ else
117+ Bnames = " second" , " first"
118+ end
119+ nA == mB ||
120+ throw (DimensionMismatch (" $(Anames[2 ]) dimension of A, $nA , does not match the $(Bnames[1 ]) dimension of B, $mB " ))
121+ mA == mC ||
122+ throw (DimensionMismatch (" $(Anames[1 ]) dimension of A, $mA , does not match the first dimension of C, $mC " ))
123+ nB == nC ||
124+ throw (DimensionMismatch (" $(Bnames[2 ]) dimension of B, $nB , does not match the second dimension of C, $nC " ))
125+ # unreachable
126+ throw (DimensionMismatch (" Unknown dimension mismatch" ))
127+ end
128+
129+ @inline function _matmul_size (C, A, B, :: Val{At} , :: Val{Bt} ) where {At,Bt}
130+ mC = size (C, 1 )
131+ nC = size (C, 2 )
132+ mA = size (A, 1 )
133+ nA = size (A, 2 )
134+ mB = size (B, 1 )
135+ nB = size (B, 2 )
136+
137+ _mA, _nA = At == ' N' ? (mA, nA) : (nA, mA)
138+ _mB, _nB = Bt == ' N' ? (mB, nB) : (nB, mB)
139+
140+ if (_nA != _mB) | (_mA != mC) | (_nB != nC)
141+ _matmul_size_error (mC, nC, _mA, _nA, _mB, _nB, At, Bt)
142+ end
143+ return mC, nC, mA, nA, mB, nB
144+ end
145+
146+ @inline _matmul_size_AB (C, A, B) = _matmul_size (C, A, B, Val (' N' ), Val (' N' ))
147+ @inline _matmul_size_AtB (C, A, B) = _matmul_size (C, A, B, Val (' T' ), Val (' N' ))
148+ @inline _matmul_size_ABt (C, A, B) = _matmul_size (C, A, B, Val (' N' ), Val (' T' ))
149+
72150function _spmatmul! (C, A, B, α, β)
73- size (A, 2 ) == size (B, 1 ) ||
74- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the first dimension of B, $(size (B,1 )) " ))
75- size (A, 1 ) == size (C, 1 ) ||
76- throw (DimensionMismatch (" first dimension of A, $(size (A,1 )) , does not match the first dimension of C, $(size (C,1 )) " ))
77- size (B, 2 ) == size (C, 2 ) ||
78- throw (DimensionMismatch (" second dimension of B, $(size (B,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
151+ Cax2 = axes (C, 2 )
152+ Aax2 = axes (A, 2 )
153+ mC, nC, mA, nA, mB, nB = _matmul_size_AB (C, A, B)
79154 nzv = nonzeros (A)
80155 rv = rowvals (A)
81- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
82- for k in axes (C, 2 )
83- @inbounds for col in axes (A,2 )
84- αxj = B[col,k] * α
156+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
157+ if α isa Bool && ! α
158+ return
159+ end
160+ B = _fix_size (B, mB, nB)
161+ C = _fix_size (C, mC, nC)
162+ for k in Cax2
163+ @inbounds for col in Aax2
164+ αxj = α isa Bool ? B[col,k] : B[col,k] * α
85165 for j in nzrange (A, col)
86- C[rv[j], k] += nzv[j]* αxj
166+ rvj = rv[j]
167+ C[rvj, k] = muladd (nzv[j], αxj, C[rvj, k])
87168 end
88169 end
89170 end
90- C
91171end
92172
93173function _At_or_Ac_mul_B! (tfun:: Function , C, A, B, α, β)
94- size (A, 2 ) == size (C, 1 ) ||
95- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the first dimension of C, $(size (C,1 )) " ))
96- size (A, 1 ) == size (B, 1 ) ||
97- throw (DimensionMismatch (" first dimension of A, $(size (A,1 )) , does not match the first dimension of B, $(size (B,1 )) " ))
98- size (B, 2 ) == size (C, 2 ) ||
99- throw (DimensionMismatch (" second dimension of B, $(size (B,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
174+ Cax2 = axes (C, 2 )
175+ Aax2 = axes (A, 2 )
176+ mC, nC, mA, nA, mB, nB = _matmul_size_AtB (C, A, B)
100177 nzv = nonzeros (A)
101178 rv = rowvals (A)
102- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
103- for k in axes (C, 2 )
104- @inbounds for col in axes (A,2 )
105- tmp = zero (eltype (C))
179+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
180+ if α isa Bool && ! α
181+ return
182+ end
183+ C0 = zero (eltype (C)) # Pre-allocate for BigFloat/BigInt etc
184+ B = _fix_size (B, mB, nB)
185+ C = _fix_size (C, mC, nC)
186+ for k in Cax2
187+ @inbounds for col in Aax2
188+ tmp = C0
106189 for j in nzrange (A, col)
107- tmp += tfun (nzv[j])* B[rv[j],k]
190+ tmp = muladd ( tfun (nzv[j]), B[rv[j], k], tmp)
108191 end
109- C[col,k] += tmp * α
192+ C[col, k] = α isa Bool ? tmp + C[col, k] : muladd (tmp, α, C[col, k])
110193 end
111194 end
112- C
113195end
114196
115197Base. @constprop :aggressive function generic_matmatmul_wrapper! (C:: StridedMatrix , tA, tB, A:: DenseMatrixUnion , B:: SparseMatrixCSCUnion2 , alpha:: Number , beta:: Number , :: LinearAlgebra.BlasFlag.SyrkHerkGemm )
@@ -127,63 +209,71 @@ Base.@constprop :aggressive generic_matmatmul_wrapper!(C::StridedMatrix, tA, tB,
127209 LinearAlgebra. _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
128210
129211function _spmul! (C:: StridedMatrix , X:: DenseMatrixUnion , A:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
130- mX, nX = size (X)
131- nX == size (A, 1 ) ||
132- throw (DimensionMismatch (" second dimension of X, $nX , does not match the first dimension of A, $(size (A,1 )) " ))
133- mX == size (C, 1 ) ||
134- throw (DimensionMismatch (" first dimension of X, $mX , does not match the first dimension of C, $(size (C,1 )) " ))
135- size (A, 2 ) == size (C, 2 ) ||
136- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
212+ Aax2 = axes (A, 2 )
213+ Xax1 = axes (X, 1 )
214+ mC, nC, mX, nX, mA, nA = _matmul_size_AB (C, X, A)
137215 rv = rowvals (A)
138216 nzv = nonzeros (A)
139- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
140- @inbounds for col in axes (A,2 ), k in nzrange (A, col)
141- Aiα = nzv[k] * α
217+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
218+ if α isa Bool && ! α
219+ return
220+ end
221+ C = _fix_size (C, mC, nC)
222+ X = _fix_size (X, mX, nX)
223+ @inbounds for col in Aax2, k in nzrange (A, col)
224+ Aiα = α isa Bool ? nzv[k] : nzv[k] * α
142225 rvk = rv[k]
143- @simd for multivec_row in axes (X,1 )
144- C[multivec_row, col] += X[multivec_row, rvk] * Aiα
226+ @simd for multivec_row in Xax1
227+ C[multivec_row, col] = muladd (X[multivec_row, rvk], Aiα,
228+ C[multivec_row, col])
145229 end
146230 end
147- C
148231end
149232function _spmul! (C:: StridedMatrix , X:: AdjOrTrans{<:Any,<:DenseMatrixUnion} , A:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
150- mX, nX = size (X)
151- nX == size (A, 1 ) ||
152- throw (DimensionMismatch (" second dimension of X, $nX , does not match the first dimension of A, $(size (A,1 )) " ))
153- mX == size (C, 1 ) ||
154- throw (DimensionMismatch (" first dimension of X, $mX , does not match the first dimension of C, $(size (C,1 )) " ))
155- size (A, 2 ) == size (C, 2 ) ||
156- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
233+ Xax1 = axes (X, 1 )
234+ Cax2 = axes (C, 2 )
235+ mC, nC, mX, nX, mA, nA = _matmul_size_AB (C, X, A)
157236 rv = rowvals (A)
158237 nzv = nonzeros (A)
159- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
160- for multivec_row in axes (X,1 ), col in axes (C, 2 )
161- @inbounds for k in nzrange (A, col)
162- C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α
238+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
239+ if α isa Bool && ! α
240+ return
241+ end
242+ C = _fix_size (C, mC, nC)
243+ X = _fix_size (X, mX, nX)
244+ @inbounds for multivec_row in Xax1, col in Cax2
245+ nzrng = nzrange (A, col)
246+ if isempty (nzrng)
247+ continue
248+ end
249+ tmp = C[multivec_row, col]
250+ for k in nzrng
251+ tmp = muladd (X[multivec_row, rv[k]],
252+ (α isa Bool ? nzv[k] : nzv[k] * α), tmp)
163253 end
254+ C[multivec_row, col] = tmp
164255 end
165- C
166256end
167257
168258function _A_mul_Bt_or_Bc! (tfun:: Function , C:: StridedMatrix , A:: AbstractMatrix , B:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
169- mA, nA = size (A)
170- nA == size (B, 2 ) ||
171- throw (DimensionMismatch (" second dimension of A, $nA , does not match the second dimension of B, $(size (B,2 )) " ))
172- mA == size (C, 1 ) ||
173- throw (DimensionMismatch (" first dimension of A, $mA , does not match the first dimension of C, $(size (C,1 )) " ))
174- size (B, 1 ) == size (C, 2 ) ||
175- throw (DimensionMismatch (" first dimension of B, $(size (B,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
259+ Bax2 = axes (B, 2 )
260+ Aax1 = axes (A, 1 )
261+ mC, nC, mA, nA, mB, nB = _matmul_size_ABt (C, A, B)
176262 rv = rowvals (B)
177263 nzv = nonzeros (B)
178- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
179- @inbounds for col in axes (B, 2 ), k in nzrange (B, col)
180- Biα = tfun (nzv[k]) * α
264+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
265+ if α isa Bool && ! α
266+ return
267+ end
268+ C = _fix_size (C, mC, nC)
269+ A = _fix_size (A, mA, nA)
270+ @inbounds for col in Bax2, k in nzrange (B, col)
271+ Biα = α isa Bool ? tfun (nzv[k]) : tfun (nzv[k]) * α
181272 rvk = rv[k]
182- @simd for multivec_col in axes (A, 1 )
183- C[multivec_col, rvk] += A[multivec_col, col] * Biα
273+ @simd for multivec_col in Aax1
274+ C[multivec_col, rvk] = muladd ( A[multivec_col, col], Biα, C[multivec_col, rvk])
184275 end
185276 end
186- C
187277end
188278
189279function * (A:: Diagonal , b:: AbstractSparseVector )
@@ -1238,7 +1328,7 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
12381328 rv = rowvals (A)
12391329 nzv = nonzeros (A)
12401330 let z = T (0 ), sumcol= z, αxj= z, aarc= z, α = α
1241- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
1331+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
12421332 @inbounds for k in axes (B,2 )
12431333 for col in axes (B,1 )
12441334 αxj = B[col,k] * α
@@ -1257,7 +1347,6 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
12571347 end
12581348 end
12591349 end
1260- C
12611350end
12621351
12631352# row range up to (and including if excl=false) diagonal
0 commit comments