@@ -4,6 +4,41 @@ using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLower
44 RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap
55using Random: rand!
66
# Generic fallback for `_fix_size`: return `M` unchanged.
#
# `nrow` and `ncol` are accepted so that every call site can pass the already
# computed dimensions, but they are ignored here. More specific methods may
# return a wrapper that stores the sizes alongside the data.
_fix_size(M, nrow, ncol) = M
8+
@static if VERSION >= v"1.11"
    # Builtin used to derive a `MemoryRef` at an element offset from another one.
    # The accessors below were written against `Core.memoryrefnew`; where that name
    # is not defined, fall back to the three-argument `Core.memoryref` so that every
    # version admitted by the `VERSION` guard above can actually run them.
    const _memoryref_at =
        isdefined(Core, :memoryrefnew) ? Core.memoryrefnew : Core.memoryref

    # An immutable fixed size wrapper for matrices to work around
    # the performance issue caused by https://github.com/JuliaLang/julia/issues/60409
    #
    # Fields:
    # - `ref`:  reference to the first element of the wrapped `Matrix` storage.
    # - `nrow`: number of rows of the wrapped (possibly transposed/adjoint) matrix.
    # - `ncol`: number of columns of the wrapped (possibly transposed/adjoint) matrix.
    #
    # The `Trans` type parameter is a `Char` tag selecting the access pattern:
    # 'N' (plain), 'T' (transpose) or 'C' (adjoint).
    #
    # None of the accessors perform bounds checking; callers must pass valid indices.
    struct _FixedSizeMatrix{Trans,R}
        ref::R
        nrow::Int
        ncol::Int
        function _FixedSizeMatrix{Trans}(ref::R, nrow, ncol) where {Trans,R}
            return new{Trans,R}(ref, nrow, ncol)
        end
    end

    # 'N': element (i, j) sits at column-major offset `nrow * (j - 1) + i`.
    @inline Base.getindex(A::_FixedSizeMatrix{'N'}, i, j) =
        @inbounds _memoryref_at(A.ref, A.nrow * (j - 1) + i, false)[]
    @inline Base.setindex!(A::_FixedSizeMatrix{'N'}, v, i, j) =
        @inbounds _memoryref_at(A.ref, A.nrow * (j - 1) + i, false)[] = v

    # 'T': element (i, j) of the wrapper is element (j, i) of the parent, whose
    # row count equals the wrapper's `ncol`, hence offset `ncol * (i - 1) + j`.
    # Values are transposed on both read and write.
    @inline Base.getindex(A::_FixedSizeMatrix{'T'}, i, j) =
        @inbounds transpose(_memoryref_at(A.ref, A.ncol * (i - 1) + j, false)[])
    @inline Base.setindex!(A::_FixedSizeMatrix{'T'}, v, i, j) =
        @inbounds _memoryref_at(A.ref, A.ncol * (i - 1) + j, false)[] = transpose(v)

    # 'C': same addressing as 'T', with `adjoint` applied on both read and write.
    @inline Base.getindex(A::_FixedSizeMatrix{'C'}, i, j) =
        @inbounds adjoint(_memoryref_at(A.ref, A.ncol * (i - 1) + j, false)[])
    @inline Base.setindex!(A::_FixedSizeMatrix{'C'}, v, i, j) =
        @inbounds _memoryref_at(A.ref, A.ncol * (i - 1) + j, false)[] = adjoint(v)

    # Wrap plain, transposed and adjoint `Matrix` values; `nrow`/`ncol` are the
    # dimensions of the argument as passed (i.e. after transposition).
    @inline _fix_size(A::Matrix, nrow, ncol) = _FixedSizeMatrix{'N'}(A.ref, nrow, ncol)
    @inline _fix_size(A::Transpose{<:Any,<:Matrix}, nrow, ncol) =
        _FixedSizeMatrix{'T'}(A.parent.ref, nrow, ncol)
    @inline _fix_size(A::Adjoint{<:Any,<:Matrix}, nrow, ncol) =
        _FixedSizeMatrix{'C'}(A.parent.ref, nrow, ncol)
end
41+
# Size used for tile buffers.
const tilebufsize = 10800  # Approximately 32k/3
843
# In matrix-vector multiplication, the correct orientation of the vector is assumed.
@@ -120,13 +155,15 @@ end
120155function _spmatmul! (C, A, B, α, β)
121156 Cax2 = axes (C, 2 )
122157 Aax2 = axes (A, 2 )
123- _matmul_size_AB (C, A, B)
158+ mC, nC, mA, nA, mB, nB = _matmul_size_AB (C, A, B)
124159 nzv = nonzeros (A)
125160 rv = rowvals (A)
126161 isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
127162 if α isa Bool && ! α
128163 return
129164 end
165+ B = _fix_size (B, mB, nB)
166+ C = _fix_size (C, mC, nC)
130167 for k in Cax2
131168 @inbounds for col in Aax2
132169 αxj = α isa Bool ? B[col,k] : B[col,k] * α
@@ -140,13 +177,15 @@ end
140177function _At_or_Ac_mul_B! (tfun:: Function , C, A, B, α, β)
141178 Cax2 = axes (C, 2 )
142179 Aax2 = axes (A, 2 )
143- _matmul_size_AtB (C, A, B)
180+ mC, nC, mA, nA, mB, nB = _matmul_size_AtB (C, A, B)
144181 nzv = nonzeros (A)
145182 rv = rowvals (A)
146183 isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
147184 if α isa Bool && ! α
148185 return
149186 end
187+ B = _fix_size (B, mB, nB)
188+ C = _fix_size (C, mC, nC)
150189 for k in Cax2
151190 @inbounds for col in Aax2
152191 tmp = zero (eltype (C))
@@ -172,13 +211,15 @@ end
172211function _spmul! (C:: StridedMatrix , X:: DenseMatrixUnion , A:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
173212 Aax2 = axes (A, 2 )
174213 Xax1 = axes (X, 1 )
175- _matmul_size_AB (C, X, A)
214+ mC, nC, mX, nX, mA, nA = _matmul_size_AB (C, X, A)
176215 rv = rowvals (A)
177216 nzv = nonzeros (A)
178217 isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
179218 if α isa Bool && ! α
180219 return
181220 end
221+ C = _fix_size (C, mC, nC)
222+ X = _fix_size (X, mX, nX)
182223 @inbounds for col in Aax2, k in nzrange (A, col)
183224 Aiα = α isa Bool ? nzv[k] : nzv[k] * α
184225 rvk = rv[k]
@@ -190,13 +231,15 @@ end
190231function _spmul! (C:: StridedMatrix , X:: AdjOrTrans{<:Any,<:DenseMatrixUnion} , A:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
191232 Xax1 = axes (X, 1 )
192233 Cax2 = axes (C, 2 )
193- _matmul_size_AB (C, X, A)
234+ mC, nC, mX, nX, mA, nA = _matmul_size_AB (C, X, A)
194235 rv = rowvals (A)
195236 nzv = nonzeros (A)
196237 isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
197238 if α isa Bool && ! α
198239 return
199240 end
241+ C = _fix_size (C, mC, nC)
242+ X = _fix_size (X, mX, nX)
200243 for multivec_row in Xax1, col in Cax2
201244 @inbounds for k in nzrange (A, col)
202245 C[multivec_row, col] +=
@@ -209,13 +252,15 @@ end
209252function _A_mul_Bt_or_Bc! (tfun:: Function , C:: StridedMatrix , A:: AbstractMatrix , B:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
210253 Bax2 = axes (B, 2 )
211254 Aax1 = axes (A, 1 )
212- _matmul_size_ABt (C, A, B)
255+ mC, nC, mA, nA, mB, nB = _matmul_size_ABt (C, A, B)
213256 rv = rowvals (B)
214257 nzv = nonzeros (B)
215258 isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
216259 if α isa Bool && ! α
217260 return
218261 end
262+ C = _fix_size (C, mC, nC)
263+ A = _fix_size (A, mA, nA)
219264 @inbounds for col in Bax2, k in nzrange (B, col)
220265 Biα = α isa Bool ? tfun (nzv[k]) : tfun (nzv[k]) * α
221266 rvk = rv[k]
0 commit comments