Skip to content

Commit 7c2ec65

Browse files
alonsoC1skshyatt
authored andcommitted
Initial (correct) implmentation of getindex for device-side CSC, CSR and SparseVector
1 parent b6d1f98 commit 7c2ec65

4 files changed

Lines changed: 75 additions & 0 deletions

File tree

src/GPUArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ using KernelAbstractions
2020
# device functionality
2121
include("device/abstractarray.jl")
2222
include("device/sparse.jl")
23+
include("device/indexing.jl")
2324

2425
# host abstractions
2526
include("host/abstractarray.jl")

src/device/indexing.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# device-level indexing
2+
using SparseArrays: nonzeroinds, nonzeros, nnz, getcolptr
3+
using Base: @propagate_inbounds
4+
5+
Base.IndexStyle(::Type{GPUSparseDeviceVector}) = Base.IndexLinear()
6+
7+
# Scalar indexing
8+
## Adapted from SparseArrays.AbstractSparseVector
9+
10+
@propagate_inbounds function Base.getindex(v::GPUSparseDeviceVector{Tv,Ti}, i::Integer) where {Tv,Ti}
11+
@boundscheck checkbounds(v, i)
12+
m = nnz(v)
13+
nzind = nonzeroinds(v)
14+
nzval = nonzeros(v)
15+
16+
ii = searchsortedfirst(nzind, convert(Ti, i))
17+
(ii <= m && nzind[ii] == i) ? nzval[ii] : zero(Tv)
18+
end
19+
20+
# TODO: Logical indexing
21+
22+
# Indexing by colon not implemented. Non-scalar indexing would allocate in device code
23+
@propagate_inbounds Base.getindex(A::AbstractGPUSparseDeviceMatrix, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2])
24+
25+
## Adapted logic from SparseArrays.AbstractSparseMatrixCSC
26+
@propagate_inbounds function Base.getindex(A::GPUSparseDeviceMatrixCSC{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti}
27+
@boundscheck checkbounds(A, i, j)
28+
colPtr, rowVal, nzVal = getcolptr(A), rowvals(A), nonzeros(A)
29+
30+
# Range of possible row indices
31+
rl = convert(Ti, @inbounds colPtr[j])
32+
rr = convert(Ti, @inbounds colPtr[j+1] - 1)
33+
(rl > rr) && return zero(Tv)
34+
35+
# possible_row = @view rowVal[rl:rr]
36+
ii = searchsortedfirst(rowVal, convert(Ti, i), rl, rr, Base.Order.Forward)
37+
(ii <= nnz(A) && rowVal[ii] == i) ? nzVal[ii] : zero(Tv)
38+
end
39+
40+
@propagate_inbounds function Base.getindex(A::GPUSparseDeviceMatrixCSR{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti}
41+
@boundscheck checkbounds(A, i, j)
42+
rowPtr, colVal, nzVal = A.rowPtr, A.colVal, A.nzVal
43+
44+
# Range of possible col indices
45+
rt = convert(Ti, @inbounds rowPtr[i])
46+
rb = convert(Ti, @inbounds rowPtr[i+1] - 1)
47+
(rt > rb) && return zero(Tv)
48+
49+
# possible_col = @view colVal[rt:rb]
50+
jj = searchsortedfirst(colVal, convert(Ti, j), rt, rb, Base.Order.Forward)
51+
(jj <= nnz(A) && colVal[jj] == j) ? nzVal[jj] : zero(Tv)
52+
end
53+
54+
## Adapted from CUDA.jl/blob/lib/cusparse/src/array.jl#L490
55+
# FIXME: Currently not correct
56+
@propagate_inbounds function Base.getindex(A::GPUSparseDeviceMatrixCOO{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti}
57+
# COO in CUDA is assumed to be sorted by row: https://docs.nvidia.com/cuda/cusparse/storage-formats.html?highlight=coo#coordinate-coo
58+
# @boundscheck checkbounds(A, i, j)
59+
rowInd, colInd, nzVal = A.rowInd, A.colInd, A.nzVal
60+
61+
# Looking for the range s.t. rowInd[r1:r2] .== i
62+
rl = searchsortedfirst(rowInd, i)
63+
(rl > nnz(A) || rowInd[rl] > i) && return 42
64+
rr = min(searchsortedfirst(rowInd, i+1, Base.Order.Forward), nnz(A)) # searchsortedlast didn't behave as expected
65+
# FIXME: colInd isn't sorted
66+
jj = searchsortedfirst(colInd, j, rl, rr, Base.Order.Forward)
67+
(jj > rr || jj == nnz(A) + 1 || colInd[jj] > j) && return zero(Tv)
68+
69+
return jj
70+
end
71+
72+
# TODO: Support BSR format

src/device/sparse.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ Base.length(g::AbstractGPUSparseDeviceMatrix) = prod(g.dims)
106106
Base.size(g::AbstractGPUSparseDeviceMatrix) = g.dims
107107
SparseArrays.nnz(g::AbstractGPUSparseDeviceMatrix) = g.nnz
108108
SparseArrays.getnzval(g::AbstractGPUSparseDeviceMatrix) = g.nzVal
109+
# FIXME: Implement `rowvals` to explicitly say why it's not available for CSR format?
109110

110111
struct GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M, A} <: AbstractSparseArray{Tv, Ti, N}
111112
rowPtr::Vi

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ const init_worker_code = quote
1010

1111
TestSuite.sparse_types(::Type{<:JLArray}) = (JLSparseVector, JLSparseMatrixCSC, JLSparseMatrixCSR)
1212
TestSuite.sparse_types(::Type{<:Array}) = (SparseVector, SparseMatrixCSC)
13+
TestSuite.sparse_device_types(::Type{<:Array}) = (GPUSparseDeviceVector, GPUSparseDeviceMatrixCSC, GPUSparseDeviceMatrixCSR, GPUSparseDeviceMatrixCOO)
1314

1415
# Disable Float16-related tests until JuliaGPU/KernelAbstractions#600 is resolved
1516
if isdefined(JLArrays.KernelAbstractions, :POCL)

0 commit comments

Comments
 (0)