Skip to content

Commit 8d7ab9c

Browse files
committed
Optimize for alpha being boolean
1 parent 5862843 commit 8d7ab9c

1 file changed

Lines changed: 22 additions & 5 deletions

File tree

src/linalg.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,12 @@ function _spmatmul!(C, A, B, α, β)
124124
nzv = nonzeros(A)
125125
rv = rowvals(A)
126126
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
127+
if α isa Bool && !α
128+
return
129+
end
127130
for k in Cax2
128131
@inbounds for col in Aax2
129-
αxj = B[col,k] * α
132+
αxj = α isa Bool ? B[col,k] : B[col,k] * α
130133
for j in nzrange(A, col)
131134
C[rv[j], k] += nzv[j]*αxj
132135
end
@@ -141,13 +144,16 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
141144
nzv = nonzeros(A)
142145
rv = rowvals(A)
143146
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
147+
if α isa Bool && !α
148+
return
149+
end
144150
for k in Cax2
145151
@inbounds for col in Aax2
146152
tmp = zero(eltype(C))
147153
for j in nzrange(A, col)
148154
tmp += tfun(nzv[j])*B[rv[j],k]
149155
end
150-
C[col,k] += tmp * α
156+
C[col,k] += α isa Bool ? tmp : tmp * α
151157
end
152158
end
153159
end
@@ -170,8 +176,11 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2
170176
rv = rowvals(A)
171177
nzv = nonzeros(A)
172178
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
179+
if α isa Bool && !α
180+
return
181+
end
173182
@inbounds for col in Aax2, k in nzrange(A, col)
174-
Aiα = nzv[k] * α
183+
Aiα = α isa Bool ? nzv[k] : nzv[k] * α
175184
rvk = rv[k]
176185
@simd for multivec_row in Xax1
177186
C[multivec_row, col] += X[multivec_row, rvk] * Aiα
@@ -185,9 +194,14 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S
185194
rv = rowvals(A)
186195
nzv = nonzeros(A)
187196
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
197+
if α isa Bool && !α
198+
return
199+
end
188200
for multivec_row in Xax1, col in Cax2
189201
@inbounds for k in nzrange(A, col)
190-
C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α
202+
C[multivec_row, col] +=
203+
isa Bool ? X[multivec_row, rv[k]] * nzv[k] :
204+
X[multivec_row, rv[k]] * nzv[k] * α)
191205
end
192206
end
193207
end
@@ -199,8 +213,11 @@ function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B
199213
rv = rowvals(B)
200214
nzv = nonzeros(B)
201215
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
216+
if α isa Bool && !α
217+
return
218+
end
202219
@inbounds for col in Bax2, k in nzrange(B, col)
203-
Biα = tfun(nzv[k]) * α
220+
Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α
204221
rvk = rv[k]
205222
@simd for multivec_col in Aax1
206223
C[multivec_col, rvk] += A[multivec_col, col] * Biα

0 commit comments

Comments
 (0)