Skip to content

Commit 7692db0

Browse files
committed
Use a fixed size wrapper to workaround julia 1.11+ bug
1 parent 8d7ab9c commit 7692db0

1 file changed

Lines changed: 50 additions & 5 deletions

File tree

src/linalg.jl

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,41 @@ using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLower
44
RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap
55
using Random: rand!
66

7+
_fix_size(M, nrow, ncol) = M
8+
9+
@static if VERSION >= v"1.11"
10+
# An immutable fixed size wrapper for matrices to work around
11+
# the performance issue caused by https://github.com/JuliaLang/julia/issues/60409
12+
struct _FixedSizeMatrix{Trans,R}
13+
ref::R
14+
nrow::Int
15+
ncol::Int
16+
function _FixedSizeMatrix{Trans}(ref::R, nrow, ncol) where {Trans,R}
17+
new{Trans,R}(ref, nrow, ncol)
18+
end
19+
end
20+
@inline Base.getindex(A::_FixedSizeMatrix{'N'}, i, j) =
21+
@inbounds Core.memoryrefnew(A.ref, A.nrow * (j - 1) + i, false)[]
22+
@inline Base.setindex!(A::_FixedSizeMatrix{'N'}, v, i, j) =
23+
@inbounds Core.memoryrefnew(A.ref, A.nrow * (j - 1) + i, false)[] = v
24+
25+
@inline Base.getindex(A::_FixedSizeMatrix{'T'}, i, j) =
26+
@inbounds transpose(Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[])
27+
@inline Base.setindex!(A::_FixedSizeMatrix{'T'}, v, i, j) =
28+
@inbounds Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[] = transpose(v)
29+
30+
@inline Base.getindex(A::_FixedSizeMatrix{'C'}, i, j) =
31+
@inbounds adjoint(Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[])
32+
@inline Base.setindex!(A::_FixedSizeMatrix{'C'}, v, i, j) =
33+
@inbounds Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[] = adjoint(v)
34+
35+
@inline _fix_size(A::Matrix, nrow, ncol) = _FixedSizeMatrix{'N'}(A.ref, nrow, ncol)
36+
@inline _fix_size(A::Transpose{<:Any,<:Matrix}, nrow, ncol) =
37+
_FixedSizeMatrix{'T'}(A.parent.ref, nrow, ncol)
38+
@inline _fix_size(A::Adjoint{<:Any,<:Matrix}, nrow, ncol) =
39+
_FixedSizeMatrix{'C'}(A.parent.ref, nrow, ncol)
40+
end
41+
742
const tilebufsize = 10800 # Approximately 32k/3
843

944
# In matrix-vector multiplication, the correct orientation of the vector is assumed.
@@ -120,13 +155,15 @@ end
120155
function _spmatmul!(C, A, B, α, β)
121156
Cax2 = axes(C, 2)
122157
Aax2 = axes(A, 2)
123-
_matmul_size_AB(C, A, B)
158+
mC, nC, mA, nA, mB, nB = _matmul_size_AB(C, A, B)
124159
nzv = nonzeros(A)
125160
rv = rowvals(A)
126161
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
127162
if α isa Bool && !α
128163
return
129164
end
165+
B = _fix_size(B, mB, nB)
166+
C = _fix_size(C, mC, nC)
130167
for k in Cax2
131168
@inbounds for col in Aax2
132169
αxj = α isa Bool ? B[col,k] : B[col,k] * α
@@ -140,13 +177,15 @@ end
140177
function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
141178
Cax2 = axes(C, 2)
142179
Aax2 = axes(A, 2)
143-
_matmul_size_AtB(C, A, B)
180+
mC, nC, mA, nA, mB, nB = _matmul_size_AtB(C, A, B)
144181
nzv = nonzeros(A)
145182
rv = rowvals(A)
146183
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
147184
if α isa Bool && !α
148185
return
149186
end
187+
B = _fix_size(B, mB, nB)
188+
C = _fix_size(C, mC, nC)
150189
for k in Cax2
151190
@inbounds for col in Aax2
152191
tmp = zero(eltype(C))
@@ -172,13 +211,15 @@ end
172211
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
173212
Aax2 = axes(A, 2)
174213
Xax1 = axes(X, 1)
175-
_matmul_size_AB(C, X, A)
214+
mC, nC, mX, nX, mA, nA = _matmul_size_AB(C, X, A)
176215
rv = rowvals(A)
177216
nzv = nonzeros(A)
178217
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
179218
if α isa Bool && !α
180219
return
181220
end
221+
C = _fix_size(C, mC, nC)
222+
X = _fix_size(X, mX, nX)
182223
@inbounds for col in Aax2, k in nzrange(A, col)
183224
Aiα = α isa Bool ? nzv[k] : nzv[k] * α
184225
rvk = rv[k]
@@ -190,13 +231,15 @@ end
190231
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number)
191232
Xax1 = axes(X, 1)
192233
Cax2 = axes(C, 2)
193-
_matmul_size_AB(C, X, A)
234+
mC, nC, mX, nX, mA, nA = _matmul_size_AB(C, X, A)
194235
rv = rowvals(A)
195236
nzv = nonzeros(A)
196237
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
197238
if α isa Bool && !α
198239
return
199240
end
241+
C = _fix_size(C, mC, nC)
242+
X = _fix_size(X, mX, nX)
200243
for multivec_row in Xax1, col in Cax2
201244
@inbounds for k in nzrange(A, col)
202245
C[multivec_row, col] +=
@@ -209,13 +252,15 @@ end
209252
function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
210253
Bax2 = axes(B, 2)
211254
Aax1 = axes(A, 1)
212-
_matmul_size_ABt(C, A, B)
255+
mC, nC, mA, nA, mB, nB = _matmul_size_ABt(C, A, B)
213256
rv = rowvals(B)
214257
nzv = nonzeros(B)
215258
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
216259
if α isa Bool && !α
217260
return
218261
end
262+
C = _fix_size(C, mC, nC)
263+
A = _fix_size(A, mA, nA)
219264
@inbounds for col in Bax2, k in nzrange(B, col)
220265
Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α
221266
rvk = rv[k]

0 commit comments

Comments
 (0)