@@ -320,13 +320,37 @@ count(pred, S::SparseMatrixCSR) = count(pred, nzvalview(S))
"""
    count(S::SparseMatrixCSR)

Count every entry exposed by `nzvalview(S)`, regardless of its value.
"""
function count(S::SparseMatrixCSR)
    return count(_ -> true, nzvalview(S))
end
321321
"""
    mul!(y, A::SparseMatrixCSR, v, α, β)

Entry point for the five-argument in-place product. The arguments are handed unchanged to
`tmul!` when `Threads.nthreads() > 1` and to `smul!` otherwise; the value returned by the
selected kernel is returned.
"""
function mul!(
    y::AbstractVector, A::SparseMatrixCSR, v::AbstractVector, α::Number, β::Number
)
    return Threads.nthreads() > 1 ? tmul!(y, A, v, α, β) : smul!(y, A, v, α, β)
end
329+
"""
    smul!(y, A::SparseMatrixCSR, v, α, β)

Serial kernel for the five-argument product: overwrite `y` with `α*A*v + β*y` and return `y`.

`y` is first scaled by `β` (zero-filled when `β == 0`, left untouched when `β == 1`), then
every stored entry of `A` is accumulated into the entry of `y` for its row.

# Throws
- `DimensionMismatch` if `size(v, 1) != A.n` or `size(y, 1) != A.m`.
"""
function smul!(
    y::AbstractVector, A::SparseMatrixCSR, v::AbstractVector, α::Number, β::Number
)
    A.n == size(v, 1) || throw(
        DimensionMismatch("A has $(A.n) columns but v has length $(size(v, 1))")
    )
    A.m == size(y, 1) || throw(
        DimensionMismatch("A has $(A.m) rows but y has length $(size(y, 1))")
    )
    if β != 1
        # Zero-fill instead of multiplying by zero so previous contents are discarded.
        β != 0 ? rmul!(y, β) : fill!(y, zero(eltype(y)))
    end
    o = getoffset(A)
    # Plain loop on purpose: this is the fallback `mul!` selects when only one thread is
    # available, so it must not go through the `@batch` machinery used by `tmul!`.
    for row = 1:size(y, 1)
        @inbounds for nz in nzrange(A, row)
            col = A.colval[nz] + o  # stored column index shifted by the matrix offset
            y[row] += A.nzval[nz] * v[col] * α
        end
    end
    return y
end
345+
346+ function tmul! (y:: AbstractVector ,A:: SparseMatrixCSR ,v:: AbstractVector , α:: Number , β:: Number )
347+ A. n == size (v, 1 ) || throw (DimensionMismatch ())
348+ A. m == size (y, 1 ) || throw (DimensionMismatch ())
349+ if β != 1
350+ β != 0 ? rmul! (y, β) : fill! (y, zero (eltype (y)))
351+ end
352+ o = getoffset (A)
353+ @batch for row = 1 : size (y, 1 )
330354 @inbounds for nz in nzrange (A,row)
331355 col = A. colval[nz]+ o
332356 y[row] += A. nzval[nz]* v[col]* α
@@ -336,6 +360,14 @@ function mul!(y::AbstractVector,A::SparseMatrixCSR,v::AbstractVector, α::Number
336360end
337361
"""
    mul!(y, A::SparseMatrixCSR, v)

Entry point for the three-argument in-place product. Forwards to `tmul!` when
`Threads.nthreads() > 1` and to `smul!` otherwise, returning the kernel's result.
"""
function mul!(y::AbstractVector, A::SparseMatrixCSR, v::AbstractVector)
    return Threads.nthreads() > 1 ? tmul!(y, A, v) : smul!(y, A, v)
end
369+
370+ function smul! (y:: AbstractVector ,A:: SparseMatrixCSR ,v:: AbstractVector )
339371 A. n == size (v, 1 ) || throw (DimensionMismatch ())
340372 A. m == size (y, 1 ) || throw (DimensionMismatch ())
341373 fill! (y, zero (eltype (y)))
@@ -349,9 +381,31 @@ function mul!(y::AbstractVector,A::SparseMatrixCSR,v::AbstractVector)
349381 return y
350382end
351383
"""
    tmul!(y, A::SparseMatrixCSR, v)

`@batch`-driven kernel for the plain product: zero `y`, accumulate `A*v` into it, and
return `y`.

Each iteration of the outer loop only writes `y[i]` for its own row index `i`.

# Throws
- `DimensionMismatch` if `size(v, 1) != A.n` or `size(y, 1) != A.m`.
"""
function tmul!(y::AbstractVector, A::SparseMatrixCSR, v::AbstractVector)
    if A.n != size(v, 1)
        throw(DimensionMismatch())
    end
    if A.m != size(y, 1)
        throw(DimensionMismatch())
    end
    fill!(y, zero(eltype(y)))
    offset = getoffset(A)
    @batch for i = 1:size(y, 1)
        @inbounds for k in nzrange(A, i)
            j = A.colval[k] + offset  # stored column index shifted by the matrix offset
            y[i] += A.nzval[k] * v[j]
        end
    end
    return y
end
397+
"""
    *(A::SparseMatrixCSR, v::Vector)

Return `A*v` in a freshly allocated vector of length `size(A, 1)`.

The output element type is `promote_type(eltype(v), eltype(A))`, matching what the
`Adjoint` method of `*` does, so that e.g. an integer `v` multiplied by a floating-point
`A` gets a buffer able to hold the products instead of one typed after `v` alone.
"""
function *(A::SparseMatrixCSR, v::Vector)
    y = similar(v, promote_type(eltype(v), eltype(A)), size(A, 1))
    return mul!(y, A, v)
end
353399
"""
    mul!(y, A::Adjoint{<:Any,<:SparseMatrixCSR}, v)

Entry point for the in-place product with an adjoint-wrapped `SparseMatrixCSR`. Forwards to
`tmul!` when `Threads.nthreads() > 1` and to `smul!` otherwise, returning the kernel's
result.
"""
function mul!(y::AbstractVector, A::Adjoint{<:Any, <:SparseMatrixCSR}, v::AbstractVector)
    return Threads.nthreads() > 1 ? tmul!(y, A, v) : smul!(y, A, v)
end
407+
408+ function smul! (y:: AbstractVector ,A:: Adjoint{<:Any, <:SparseMatrixCSR} ,v:: AbstractVector )
355409 P = A. parent
356410 P. n == size (y, 1 ) || throw (DimensionMismatch ())
357411 P. m == size (v, 1 ) || throw (DimensionMismatch ())
@@ -366,6 +420,21 @@ function mul!(y::AbstractVector,A::Adjoint{T, <:SparseMatrixCSR},v::AbstractVect
366420 return y
367421end
368422
"""
    tmul!(y, A::Adjoint{<:Any,<:SparseMatrixCSR}, v)

`@batch`-driven kernel for the product with an adjoint-wrapped matrix: zero `y`, then walk
the rows of the parent matrix and scatter `nzval * v[row]` into `y` at the (offset-shifted)
stored column index. Returns `y`.

Different rows can hit the same column, hence the `@atomic` on the update.

# Throws
- `DimensionMismatch` if `size(y, 1) != P.n` or `size(v, 1) != P.m`, with `P = A.parent`.
"""
function tmul!(y::AbstractVector, A::Adjoint{<:Any, <:SparseMatrixCSR}, v::AbstractVector)
    P = A.parent
    if P.n != size(y, 1)
        throw(DimensionMismatch())
    end
    if P.m != size(v, 1)
        throw(DimensionMismatch())
    end
    fill!(y, zero(eltype(y)))
    offset = getoffset(P)
    @batch for i = 1:size(P, 1)
        for k in nzrange(P, i)
            j = P.colval[k] + offset
            # NOTE(review): which `@atomic` is in scope is not visible from here; confirm
            # it is one that accepts array-element updates.
            @atomic y[j] += P.nzval[k] * v[i]
        end
    end
    return y
end
437+
"""
    *(A::Adjoint{T,<:SparseMatrixCSR}, v::AbstractVector) where T

Return the product of the adjoint-wrapped matrix with `v` in a freshly allocated vector of
length `size(A, 1)` whose element type is `promote_type(eltype(v), T)`.
"""
function *(A::Adjoint{T, <:SparseMatrixCSR}, v::AbstractVector) where T
    out = similar(v, promote_type(eltype(v), T), size(A, 1))
    return mul!(out, A, v)
end
370439
371440function show (io:: IO , :: MIME"text/plain" , S:: SparseMatrixCSR )
0 commit comments