Skip to content

Commit b1fc176

Browse files
Add polyester for SpMV kernels
1 parent 7aab96d commit b1fc176

2 files changed

Lines changed: 75 additions & 3 deletions

File tree

src/SparseMatricesCSR.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ export SymSparseMatrixCSR
1414
export sparsecsr, symsparsecsr
1515
export colvals, getBi, getoffset
1616

17+
import Polyester: @batch
18+
import Atomix: @atomic
19+
1720
include("SparseMatrixCSR.jl")
1821

1922
include("SymSparseMatrixCSR.jl")

src/SparseMatrixCSR.jl

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,14 +319,22 @@ If `pred` not given, it counts the number of `true` values.
319319
function count(pred, S::SparseMatrixCSR)
    # Delegate to `count` over whatever `nzvalview(S)` yields.
    return count(pred, nzvalview(S))
end

# Predicate-free method: the always-true predicate makes every entry yielded by
# `nzvalview(S)` count once.
function count(S::SparseMatrixCSR)
    return count(_ -> true, nzvalview(S))
end
321321

322-
function mul!(y::AbstractVector,A::SparseMatrixCSR,v::AbstractVector, α::Number, β::Number)
322+
"""
    mul!(y, A::SparseMatrixCSR, v, α, β)

Five-argument in-place product: scale `y` by `β`, add `α * A * v` to it, and return `y`.

This method only selects a kernel: the `@batch`-based `tmul!` when Julia was started
with more than one thread, the `smul!` kernel otherwise. Dimension checks and the
handling of `β` happen inside those kernels.

This replaces a copy-pasted duplicate of the three-argument `Adjoint` method, which is
defined again further down the file. The duplicate left the five-argument product
without a `mul!` entry point and caused one `Adjoint` definition to overwrite the other.
"""
function mul!(
    y::AbstractVector,
    A::SparseMatrixCSR,
    v::AbstractVector,
    α::Number,
    β::Number,
)
    if Threads.nthreads() > 1
        return tmul!(y, A, v, α, β)
    else
        return smul!(y, A, v, α, β)
    end
end
329+
330+
function smul!(y::AbstractVector,A::SparseMatrixCSR,v::AbstractVector, α::Number, β::Number)
323331
A.n == size(v, 1) || throw(DimensionMismatch())
324332
A.m == size(y, 1) || throw(DimensionMismatch())
325333
if β != 1
326334
β != 0 ? rmul!(y, β) : fill!(y, zero(eltype(y)))
327335
end
328336
o = getoffset(A)
329-
for row = 1:size(y, 1)
337+
@batch for row = 1:size(y, 1)
330338
@inbounds for nz in nzrange(A,row)
331339
col = A.colval[nz]+o
332340
y[row] += A.nzval[nz]*v[col]*α
@@ -335,7 +343,31 @@ function mul!(y::AbstractVector,A::SparseMatrixCSR,v::AbstractVector, α::Number
335343
return y
336344
end
337345

338-
function mul!(y::AbstractVector,A::SparseMatrixCSR,v::AbstractVector)
346+
"""
    tmul!(y, A::SparseMatrixCSR, v, α, β)

Threaded kernel for the five-argument product: `y` is first scaled by `β` (or zeroed
when `β == 0`), then `A.nzval[k] * v[j] * α` is accumulated into `y[i]` for every
stored entry `k` of row `i`. Returns `y`.

Rows are distributed with `@batch`; each iteration writes only to its own `y[i]`.

Throws `DimensionMismatch` when `size(v, 1) != A.n` or `size(y, 1) != A.m`.
"""
function tmul!(
    y::AbstractVector,
    A::SparseMatrixCSR,
    v::AbstractVector,
    α::Number,
    β::Number,
)
    size(v, 1) == A.n || throw(DimensionMismatch())
    size(y, 1) == A.m || throw(DimensionMismatch())
    # β == 1 leaves `y` untouched; β == 0 overwrites it with zeros instead of
    # multiplying, every other value rescales it in place.
    if β == 0
        fill!(y, zero(eltype(y)))
    elseif β != 1
        rmul!(y, β)
    end
    # Shift added to every stored column index before it is used to index `v`.
    shift = getoffset(A)
    @batch for i = 1:size(y, 1)
        @inbounds for k in nzrange(A, i)
            j = A.colval[k] + shift
            y[i] += A.nzval[k] * v[j] * α
        end
    end
    return y
end
361+
362+
"""
    mul!(y, A::SparseMatrixCSR, v)

Overwrite `y` with the product `A * v` and return it.

This method only selects a kernel: the `@batch`-based `tmul!` when Julia was started
with more than one thread, the `smul!` kernel otherwise. Dimension checks are performed
by the kernels.
"""
function mul!(y::AbstractVector, A::SparseMatrixCSR, v::AbstractVector)
    # The original signature carried a `where T` that appeared nowhere in the argument
    # list; it is dropped to avoid the unused-type-variable warning.
    if Threads.nthreads() > 1
        return tmul!(y, A, v)
    else
        return smul!(y, A, v)
    end
end
369+
370+
function smul!(y::AbstractVector,A::SparseMatrixCSR,v::AbstractVector)
339371
A.n == size(v, 1) || throw(DimensionMismatch())
340372
A.m == size(y, 1) || throw(DimensionMismatch())
341373
fill!(y, zero(eltype(y)))
@@ -349,9 +381,31 @@ function mul!(y::AbstractVector,A::SparseMatrixCSR,v::AbstractVector)
349381
return y
350382
end
351383

384+
"""
    tmul!(y, A::SparseMatrixCSR, v)

Threaded kernel for `y = A * v`: `y` is zero-filled, then `A.nzval[k] * v[j]` is
accumulated into `y[i]` for every stored entry `k` of row `i`. Returns `y`.

Rows are distributed with `@batch`; each iteration writes only to its own `y[i]`.

Throws `DimensionMismatch` when `size(v, 1) != A.n` or `size(y, 1) != A.m`.
"""
function tmul!(y::AbstractVector, A::SparseMatrixCSR, v::AbstractVector)
    size(v, 1) == A.n || throw(DimensionMismatch())
    size(y, 1) == A.m || throw(DimensionMismatch())
    fill!(y, zero(eltype(y)))
    # Shift added to every stored column index before it is used to index `v`.
    shift = getoffset(A)
    @batch for i = 1:size(y, 1)
        @inbounds for k in nzrange(A, i)
            y[i] += A.nzval[k] * v[A.colval[k] + shift]
        end
    end
    return y
end
397+
352398
"""
    *(A::SparseMatrixCSR, v::Vector)

Allocate an output with `similar(v, size(A, 1))` (same element type as `v`) and fill it
through `mul!`, returning the result.
"""
function *(A::SparseMatrixCSR, v::Vector)
    y = similar(v, size(A, 1))
    return mul!(y, A, v)
end
353399

354400
"""
    mul!(y, A::Adjoint{T,<:SparseMatrixCSR}, v)

In-place product with a lazily wrapped `SparseMatrixCSR`. Only selects the kernel:
`tmul!` when more than one Julia thread is available, `smul!` otherwise, and returns
whatever the chosen kernel returns.
"""
function mul!(
    y::AbstractVector,
    A::Adjoint{T, <:SparseMatrixCSR},
    v::AbstractVector,
) where T
    Threads.nthreads() > 1 && return tmul!(y, A, v)
    return smul!(y, A, v)
end
407+
408+
function smul!(y::AbstractVector,A::Adjoint{T, <:SparseMatrixCSR},v::AbstractVector) where T
355409
P = A.parent
356410
P.n == size(y, 1) || throw(DimensionMismatch())
357411
P.m == size(v, 1) || throw(DimensionMismatch())
@@ -366,6 +420,21 @@ function mul!(y::AbstractVector,A::Adjoint{T, <:SparseMatrixCSR},v::AbstractVect
366420
return y
367421
end
368422

423+
"""
    tmul!(y, A::Adjoint{T,<:SparseMatrixCSR}, v)

Threaded kernel for the product of a wrapped `SparseMatrixCSR` with `v`. `y` is
zero-filled, then for every stored entry `nz` of parent row `row` the value
`P.nzval[nz] * v[row]` is added to `y[col]`. Returns `y`.

Parent rows are distributed with `@batch`. Unlike the non-adjoint kernels, the write
target is indexed by *column*, and several rows can hold an entry in the same column, so
concurrent iterations may update the same `y[col]`. The accumulation therefore goes
through `@atomic` (imported from Atomix by the enclosing module); a plain `+=` here
would silently lose updates. This requires `eltype(y)` to be a type Atomix can update
atomically.

NOTE(review): stored values are used without conjugation, so for complex `T` this is the
transpose product rather than the adjoint — confirm this matches the serial `smul!`.

Throws `DimensionMismatch` when `size(y, 1) != P.n` or `size(v, 1) != P.m`, where
`P = A.parent`.
"""
function tmul!(
    y::AbstractVector,
    A::Adjoint{T, <:SparseMatrixCSR},
    v::AbstractVector,
) where T
    P = A.parent
    P.n == size(y, 1) || throw(DimensionMismatch())
    P.m == size(v, 1) || throw(DimensionMismatch())
    fill!(y, zero(eltype(y)))
    # Shift added to every stored column index before it is used to index `y`.
    o = getoffset(P)
    @batch for row = 1:size(P, 1)
        for nz in nzrange(P, row)
            col = P.colval[nz] + o
            # Scatter into a slot other rows may also be writing: must be atomic.
            @atomic y[col] += P.nzval[nz] * v[row]
        end
    end
    return y
end
437+
369438
"""
    *(A::Adjoint{T,<:SparseMatrixCSR}, v::AbstractVector)

Allocate an output of length `size(A, 1)` whose element type is
`promote_type(eltype(v), T)` and fill it through `mul!`, returning the result.
"""
function *(A::Adjoint{T, <:SparseMatrixCSR}, v::AbstractVector) where T
    Tout = promote_type(eltype(v), T)
    y = similar(v, Tout, size(A, 1))
    return mul!(y, A, v)
end
370439

371440
function show(io::IO, ::MIME"text/plain", S::SparseMatrixCSR)

0 commit comments

Comments
 (0)