Skip to content

Commit 867815f

Browse files
committed
Use muladd when possible in sparse matrix multiplication
1 parent 7692db0 commit 867815f

1 file changed

Lines changed: 19 additions & 11 deletions

File tree

src/linalg.jl

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,8 @@ function _spmatmul!(C, A, B, α, β)
168168
@inbounds for col in Aax2
169169
αxj = α isa Bool ? B[col,k] : B[col,k] * α
170170
for j in nzrange(A, col)
171-
C[rv[j], k] += nzv[j]*αxj
171+
rvj = rv[j]
172+
C[rvj, k] = muladd(nzv[j], αxj, C[rvj, k])
172173
end
173174
end
174175
end
@@ -184,15 +185,16 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
184185
if α isa Bool && !α
185186
return
186187
end
188+
C0 = zero(eltype(C)) # Pre-allocate for BigFloat/BigInt etc
187189
B = _fix_size(B, mB, nB)
188190
C = _fix_size(C, mC, nC)
189191
for k in Cax2
190192
@inbounds for col in Aax2
191-
tmp = zero(eltype(C))
193+
tmp = C0
192194
for j in nzrange(A, col)
193-
tmp += tfun(nzv[j])*B[rv[j],k]
195+
tmp = muladd(tfun(nzv[j]), B[rv[j], k], tmp)
194196
end
195-
C[col,k] += α isa Bool ? tmp : tmp * α
197+
C[col, k] = α isa Bool ? tmp + C[col, k] : muladd(tmp, α, C[col, k])
196198
end
197199
end
198200
end
@@ -224,7 +226,8 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2
224226
Aiα = α isa Bool ? nzv[k] : nzv[k] * α
225227
rvk = rv[k]
226228
@simd for multivec_row in Xax1
227-
C[multivec_row, col] += X[multivec_row, rvk] * Aiα
229+
C[multivec_row, col] = muladd(X[multivec_row, rvk], Aiα,
230+
C[multivec_row, col])
228231
end
229232
end
230233
end
@@ -240,12 +243,17 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S
240243
end
241244
C = _fix_size(C, mC, nC)
242245
X = _fix_size(X, mX, nX)
243-
for multivec_row in Xax1, col in Cax2
244-
@inbounds for k in nzrange(A, col)
245-
C[multivec_row, col] +=
246-
(α isa Bool ? X[multivec_row, rv[k]] * nzv[k] :
247-
X[multivec_row, rv[k]] * nzv[k] * α)
246+
@inbounds for multivec_row in Xax1, col in Cax2
247+
nzrng = nzrange(A, col)
248+
if isempty(nzrng)
249+
continue
248250
end
251+
tmp = C[multivec_row, col]
252+
for k in nzrng
253+
tmp = muladd(X[multivec_row, rv[k]],
254+
(α isa Bool ? nzv[k] : nzv[k] * α), tmp)
255+
end
256+
C[multivec_row, col] = tmp
249257
end
250258
end
251259

@@ -265,7 +273,7 @@ function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B
265273
Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α
266274
rvk = rv[k]
267275
@simd for multivec_col in Aax1
268-
C[multivec_col, rvk] += A[multivec_col, col] * Biα
276+
C[multivec_col, rvk] = muladd(A[multivec_col, col], Biα, C[multivec_col, rvk])
269277
end
270278
end
271279
end

0 commit comments

Comments
 (0)