Skip to content

Commit 5862843

Browse files
committed
Pre-compute matrix axes out of the loop
1 parent 715dd78 commit 5862843

1 file changed

Lines changed: 19 additions & 9 deletions

File tree

src/linalg.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,14 @@ end
118118
@inline _matmul_size_ABt(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('T'))
119119

120120
function _spmatmul!(C, A, B, α, β)
121+
Cax2 = axes(C, 2)
122+
Aax2 = axes(A, 2)
121123
_matmul_size_AB(C, A, B)
122124
nzv = nonzeros(A)
123125
rv = rowvals(A)
124126
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
125-
for k in axes(C, 2)
126-
@inbounds for col in axes(A,2)
127+
for k in Cax2
128+
@inbounds for col in Aax2
127129
αxj = B[col,k] * α
128130
for j in nzrange(A, col)
129131
C[rv[j], k] += nzv[j]*αxj
@@ -133,12 +135,14 @@ function _spmatmul!(C, A, B, α, β)
133135
end
134136

135137
function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
138+
Cax2 = axes(C, 2)
139+
Aax2 = axes(A, 2)
136140
_matmul_size_AtB(C, A, B)
137141
nzv = nonzeros(A)
138142
rv = rowvals(A)
139143
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
140-
for k in axes(C, 2)
141-
@inbounds for col in axes(A,2)
144+
for k in Cax2
145+
@inbounds for col in Aax2
142146
tmp = zero(eltype(C))
143147
for j in nzrange(A, col)
144148
tmp += tfun(nzv[j])*B[rv[j],k]
@@ -160,39 +164,45 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB
160164
return C
161165
end
162166
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
167+
Aax2 = axes(A, 2)
168+
Xax1 = axes(X, 1)
163169
_matmul_size_AB(C, X, A)
164170
rv = rowvals(A)
165171
nzv = nonzeros(A)
166172
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
167-
@inbounds for col in axes(A,2), k in nzrange(A, col)
173+
@inbounds for col in Aax2, k in nzrange(A, col)
168174
Aiα = nzv[k] * α
169175
rvk = rv[k]
170-
@simd for multivec_row in axes(X,1)
176+
@simd for multivec_row in Xax1
171177
C[multivec_row, col] += X[multivec_row, rvk] * Aiα
172178
end
173179
end
174180
end
175181
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number)
182+
Xax1 = axes(X, 1)
183+
Cax2 = axes(C, 2)
176184
_matmul_size_AB(C, X, A)
177185
rv = rowvals(A)
178186
nzv = nonzeros(A)
179187
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
180-
for multivec_row in axes(X,1), col in axes(C, 2)
188+
for multivec_row in Xax1, col in Cax2
181189
@inbounds for k in nzrange(A, col)
182190
C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α
183191
end
184192
end
185193
end
186194

187195
function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
196+
Bax2 = axes(B, 2)
197+
Aax1 = axes(A, 1)
188198
_matmul_size_ABt(C, A, B)
189199
rv = rowvals(B)
190200
nzv = nonzeros(B)
191201
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
192-
@inbounds for col in axes(B, 2), k in nzrange(B, col)
202+
@inbounds for col in Bax2, k in nzrange(B, col)
193203
Biα = tfun(nzv[k]) * α
194204
rvk = rv[k]
195-
@simd for multivec_col in axes(A,1)
205+
@simd for multivec_col in Aax1
196206
C[multivec_col, rvk] += A[multivec_col, col] * Biα
197207
end
198208
end

0 commit comments

Comments
 (0)