Skip to content

Commit e3f6255

Browse files
committed
Use muladd when possible in sparse matrix multiplication
1 parent 4cba083 commit e3f6255

1 file changed

Lines changed: 19 additions & 11 deletions

File tree

src/linalg.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ function _spmatmul!(C, A, B, α, β)
169169
@inbounds for col in Aax2
170170
αxj = α isa Bool ? B[col,k] : B[col,k] * α
171171
for j in nzrange(A, col)
172-
C[rv[j], k] += nzv[j]*αxj
172+
rvj = rv[j]
173+
C[rvj, k] = muladd(nzv[j], αxj, C[rvj, k])
173174
end
174175
end
175176
end
@@ -185,15 +186,16 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
185186
if α isa Bool && !α
186187
return
187188
end
189+
C0 = zero(eltype(C)) # Pre-allocate for BigFloat/BigInt etc
188190
B = _fix_size(B, mB, nB)
189191
C = _fix_size(C, mC, nC)
190192
for k in Cax2
191193
@inbounds for col in Aax2
192-
tmp = zero(eltype(C))
194+
tmp = C0
193195
for j in nzrange(A, col)
194-
tmp += tfun(nzv[j])*B[rv[j],k]
196+
tmp = muladd(tfun(nzv[j]), B[rv[j], k], tmp)
195197
end
196-
C[col,k] += α isa Bool ? tmp : tmp * α
198+
C[col, k] = α isa Bool ? tmp + C[col, k] : muladd(tmp, α, C[col, k])
197199
end
198200
end
199201
end
@@ -225,7 +227,8 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2
225227
Aiα = α isa Bool ? nzv[k] : nzv[k] * α
226228
rvk = rv[k]
227229
@simd for multivec_row in Xax1
228-
C[multivec_row, col] += X[multivec_row, rvk] * Aiα
230+
C[multivec_row, col] = muladd(X[multivec_row, rvk], Aiα,
231+
C[multivec_row, col])
229232
end
230233
end
231234
end
@@ -241,12 +244,17 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S
241244
end
242245
C = _fix_size(C, mC, nC)
243246
X = _fix_size(X, mX, nX)
244-
for multivec_row in Xax1, col in Cax2
245-
@inbounds for k in nzrange(A, col)
246-
C[multivec_row, col] +=
247-
isa Bool ? X[multivec_row, rv[k]] * nzv[k] :
248-
X[multivec_row, rv[k]] * nzv[k] * α)
247+
@inbounds for multivec_row in Xax1, col in Cax2
248+
nzrng = nzrange(A, col)
249+
if isempty(nzrng)
250+
continue
249251
end
252+
tmp = C[multivec_row, col]
253+
for k in nzrng
254+
tmp = muladd(X[multivec_row, rv[k]],
255+
isa Bool ? nzv[k] : nzv[k] * α), tmp)
256+
end
257+
C[multivec_row, col] = tmp
250258
end
251259
end
252260

@@ -266,7 +274,7 @@ function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B
266274
Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α
267275
rvk = rv[k]
268276
@simd for multivec_col in Aax1
269-
C[multivec_col, rvk] += A[multivec_col, col] * Biα
277+
C[multivec_col, rvk] = muladd(A[multivec_col, col], Biα, C[multivec_col, rvk])
270278
end
271279
end
272280
end

0 commit comments

Comments
 (0)