@@ -319,14 +319,22 @@ If `pred` not given, it counts the number of `true` values.
function count(pred, S::SparseMatrixCSR)
    # Only the entries exposed by `nzvalview(S)` are tested against `pred`.
    stored = nzvalview(S)
    return count(pred, stored)
end
function count(S::SparseMatrixCSR)
    # The predicate is constantly `true`, so the result is the number of entries
    # exposed by `nzvalview(S)`, whatever their values are.
    # NOTE(review): the accompanying docstring says this counts `true` values;
    # confirm which of the two behaviours is intended.
    return count(_ -> true, nzvalview(S))
end
321321
322- function mul! (y:: AbstractVector ,A:: SparseMatrixCSR ,v:: AbstractVector , α:: Number , β:: Number )
# Five-argument in-place product for a CSR matrix: `y = α*A*v + β*y`.
#
# Picks the kernel from the number of Julia threads: `tmul!` when more than one
# thread is available, `smul!` otherwise. Both kernels validate the dimensions
# and return `y`, which is passed through to the caller.
#
# This method takes `A::SparseMatrixCSR` together with `α` and `β` so that the
# five-argument `smul!`/`tmul!` kernels are reachable through `mul!`; the
# three-argument `Adjoint` dispatcher is defined separately further down.
function mul!(
    y::AbstractVector,
    A::SparseMatrixCSR,
    v::AbstractVector,
    α::Number,
    β::Number,
)
    if Threads.nthreads() > 1
        return tmul!(y, A, v, α, β)
    else
        return smul!(y, A, v, α, β)
    end
end
329+
330+ function smul! (y:: AbstractVector ,A:: SparseMatrixCSR ,v:: AbstractVector , α:: Number , β:: Number )
323331 A. n == size (v, 1 ) || throw (DimensionMismatch ())
324332 A. m == size (y, 1 ) || throw (DimensionMismatch ())
325333 if β != 1
326334 β != 0 ? rmul! (y, β) : fill! (y, zero (eltype (y)))
327335 end
328336 o = getoffset (A)
329- for row = 1 : size (y, 1 )
337+ @batch for row = 1 : size (y, 1 )
330338 @inbounds for nz in nzrange (A,row)
331339 col = A. colval[nz]+ o
332340 y[row] += A. nzval[nz]* v[col]* α
@@ -335,7 +343,31 @@ function mul!(y::AbstractVector,A::SparseMatrixCSR,v::AbstractVector, α::Number
335343 return y
336344end
337345
338- function mul! (y:: AbstractVector ,A:: SparseMatrixCSR ,v:: AbstractVector )
"""
    tmul!(y, A::SparseMatrixCSR, v, α, β)

Overwrite `y` with `α*A*v + β*y` and return `y`. The loop over the rows of `A`
is wrapped in `@batch`; each iteration only updates its own `y[row]`.

Throws `DimensionMismatch` when `A.n != size(v, 1)` or `A.m != size(y, 1)`.
"""
function tmul!(
    y::AbstractVector,
    A::SparseMatrixCSR,
    v::AbstractVector,
    α::Number,
    β::Number,
)
    if A.n != size(v, 1)
        throw(DimensionMismatch())
    end
    if A.m != size(y, 1)
        throw(DimensionMismatch())
    end
    # Apply β to the current contents of y. β == 0 resets y outright (so old
    # contents never leak through), β == 1 leaves y untouched.
    if β == 0
        fill!(y, zero(eltype(y)))
    elseif β != 1
        rmul!(y, β)
    end
    shift = getoffset(A)
    cols = A.colval
    vals = A.nzval
    @batch for i = 1:size(y, 1)
        @inbounds for k in nzrange(A, i)
            # Stored column indices are shifted by the offset before indexing v.
            j = cols[k] + shift
            y[i] = y[i] + vals[k] * v[j] * α
        end
    end
    return y
end
361+
# Three-argument in-place product `y = A*v` for a CSR matrix.
#
# Picks the kernel from the number of Julia threads: `tmul!` when more than one
# thread is available, `smul!` otherwise, and returns the kernel's result.
# No `where` clause here: the signature has no type parameter to bind.
function mul!(y::AbstractVector, A::SparseMatrixCSR, v::AbstractVector)
    if Threads.nthreads() > 1
        return tmul!(y, A, v)
    else
        return smul!(y, A, v)
    end
end
369+
370+ function smul! (y:: AbstractVector ,A:: SparseMatrixCSR ,v:: AbstractVector )
339371 A. n == size (v, 1 ) || throw (DimensionMismatch ())
340372 A. m == size (y, 1 ) || throw (DimensionMismatch ())
341373 fill! (y, zero (eltype (y)))
@@ -349,9 +381,31 @@ function mul!(y::AbstractVector,A::SparseMatrixCSR,v::AbstractVector)
349381 return y
350382end
351383
"""
    tmul!(y, A::SparseMatrixCSR, v)

Overwrite `y` with `A*v` and return `y`. The loop over the rows of `A` is
wrapped in `@batch`; each iteration only updates its own `y[row]`.

Throws `DimensionMismatch` when `A.n != size(v, 1)` or `A.m != size(y, 1)`.
"""
function tmul!(y::AbstractVector, A::SparseMatrixCSR, v::AbstractVector)
    if A.n != size(v, 1)
        throw(DimensionMismatch())
    end
    if A.m != size(y, 1)
        throw(DimensionMismatch())
    end
    # Start from zero so the accumulation below yields exactly A*v.
    fill!(y, zero(eltype(y)))
    shift = getoffset(A)
    cols = A.colval
    vals = A.nzval
    @batch for i = 1:size(y, 1)
        @inbounds for k in nzrange(A, i)
            # Stored column indices are shifted by the offset before indexing v.
            j = cols[k] + shift
            y[i] = y[i] + vals[k] * v[j]
        end
    end
    return y
end
397+
# Matrix-vector product `A*v` for a CSR matrix, returned in a freshly allocated
# vector of length `size(A, 1)`.
#
# The result element type is the promotion of the matrix and vector element
# types (the `Adjoint` method of `*` promotes in the same way); taking it from
# `v` alone fails, for instance, for a floating-point matrix and an integer
# vector.
function *(A::SparseMatrixCSR, v::Vector)
    y = similar(v, promote_type(eltype(A), eltype(v)), size(A, 1))
    return mul!(y, A, v)
end
353399
# Three-argument in-place product `y = A*v` where `A` is the lazy adjoint of a
# CSR matrix. Uses `tmul!` when Julia runs with more than one thread and `smul!`
# otherwise, and returns the kernel's result.
function mul!(
    y::AbstractVector,
    A::Adjoint{T,<:SparseMatrixCSR},
    v::AbstractVector,
) where T
    return Threads.nthreads() > 1 ? tmul!(y, A, v) : smul!(y, A, v)
end
407+
408+ function smul! (y:: AbstractVector ,A:: Adjoint{T, <:SparseMatrixCSR} ,v:: AbstractVector ) where T
355409 P = A. parent
356410 P. n == size (y, 1 ) || throw (DimensionMismatch ())
357411 P. m == size (v, 1 ) || throw (DimensionMismatch ())
@@ -366,6 +420,21 @@ function mul!(y::AbstractVector,A::Adjoint{T, <:SparseMatrixCSR},v::AbstractVect
366420 return y
367421end
368422
"""
    tmul!(y, A::Adjoint{T,<:SparseMatrixCSR}, v)

Overwrite `y` with `A*v`, where `A` wraps the CSR matrix `P = A.parent`, and
return `y`. Every stored entry of `P` at `(row, col)` contributes
`P.nzval[nz] * v[row]` to `y[col]`.

Throws `DimensionMismatch` when `P.n != size(y, 1)` or `P.m != size(v, 1)`.

The traversal is deliberately sequential: it walks the rows of `P` but writes to
`y[col]`, and several rows can hold entries in the same column. Splitting the
row loop across threads would therefore let two threads update the same `y[col]`
at once and lose contributions. A threaded variant needs per-task accumulators
(or atomic updates) that are reduced into `y` afterwards.
"""
function tmul!(
    y::AbstractVector,
    A::Adjoint{T,<:SparseMatrixCSR},
    v::AbstractVector,
) where T
    P = A.parent
    P.n == size(y, 1) || throw(DimensionMismatch())
    P.m == size(v, 1) || throw(DimensionMismatch())
    fill!(y, zero(eltype(y)))
    o = getoffset(P)
    for row = 1:size(P, 1)
        for nz in nzrange(P, row)
            # Stored column indices are shifted by the offset before indexing y.
            col = P.colval[nz] + o
            y[col] += P.nzval[nz] * v[row]
        end
    end
    return y
end
437+
# Product `A*v` where `A` is the lazy adjoint of a CSR matrix. Allocates the
# result with length `size(A, 1)` and with the promotion of `eltype(v)` and `T`
# as its element type, then fills it through `mul!`.
function *(A::Adjoint{T,<:SparseMatrixCSR}, v::AbstractVector) where T
    Tout = promote_type(eltype(v), T)
    out = similar(v, Tout, size(A, 1))
    return mul!(out, A, v)
end
370439
371440function show (io:: IO , :: MIME"text/plain" , S:: SparseMatrixCSR )
0 commit comments