|
| 1 | +# device-level indexing |
| 2 | +using SparseArrays: nonzeroinds, nonzeros, nnz, getcolptr |
| 3 | +using Base: @propagate_inbounds |
| 4 | + |
| 5 | +Base.IndexStyle(::Type{GPUSparseDeviceVector}) = Base.IndexLinear() |
| 6 | + |
| 7 | +# Scalar indexing |
| 8 | +## Adapted from SparseArrays.AbstractSparseVector |
| 9 | + |
| 10 | +@propagate_inbounds function Base.getindex(v::GPUSparseDeviceVector{Tv,Ti}, i::Integer) where {Tv,Ti} |
| 11 | + @boundscheck checkbounds(v, i) |
| 12 | + m = nnz(v) |
| 13 | + nzind = nonzeroinds(v) |
| 14 | + nzval = nonzeros(v) |
| 15 | + |
| 16 | + ii = searchsortedfirst(nzind, convert(Ti, i)) |
| 17 | + (ii <= m && nzind[ii] == i) ? nzval[ii] : zero(Tv) |
| 18 | +end |
| 19 | + |
| 20 | +# TODO: Logical indexing |
| 21 | + |
| 22 | +# Indexing by colon not implemented. Non-scalar indexing would allocate in device code |
| 23 | +@propagate_inbounds Base.getindex(A::AbstractGPUSparseDeviceMatrix, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2]) |
| 24 | + |
| 25 | +## Adapted logic from SparseArrays.AbstractSparseMatrixCSC |
| 26 | +@propagate_inbounds function Base.getindex(A::GPUSparseDeviceMatrixCSC{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti} |
| 27 | + @boundscheck checkbounds(A, i, j) |
| 28 | + colPtr, rowVal, nzVal = getcolptr(A), rowvals(A), nonzeros(A) |
| 29 | + |
| 30 | + # Range of possible row indices |
| 31 | + rl = convert(Ti, @inbounds colPtr[j]) |
| 32 | + rr = convert(Ti, @inbounds colPtr[j+1] - 1) |
| 33 | + (rl > rr) && return zero(Tv) |
| 34 | + |
| 35 | + # possible_row = @view rowVal[rl:rr] |
| 36 | + ii = searchsortedfirst(rowVal, convert(Ti, i), rl, rr, Base.Order.Forward) |
| 37 | + (ii <= nnz(A) && rowVal[ii] == i) ? nzVal[ii] : zero(Tv) |
| 38 | +end |
| 39 | + |
| 40 | +@propagate_inbounds function Base.getindex(A::GPUSparseDeviceMatrixCSR{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti} |
| 41 | + @boundscheck checkbounds(A, i, j) |
| 42 | + rowPtr, colVal, nzVal = A.rowPtr, A.colVal, A.nzVal |
| 43 | + |
| 44 | + # Range of possible col indices |
| 45 | + rt = convert(Ti, @inbounds rowPtr[i]) |
| 46 | + rb = convert(Ti, @inbounds rowPtr[i+1] - 1) |
| 47 | + (rt > rb) && return zero(Tv) |
| 48 | + |
| 49 | + # possible_col = @view colVal[rt:rb] |
| 50 | + jj = searchsortedfirst(colVal, convert(Ti, j), rt, rb, Base.Order.Forward) |
| 51 | + (jj <= nnz(A) && colVal[jj] == j) ? nzVal[jj] : zero(Tv) |
| 52 | +end |
| 53 | + |
| 54 | +## Adapted from CUDA.jl/blob/lib/cusparse/src/array.jl#L490 |
| 55 | +# FIXME: Currently not correct |
| 56 | +@propagate_inbounds function Base.getindex(A::GPUSparseDeviceMatrixCOO{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti} |
| 57 | + # COO in CUDA is assumed to be sorted by row: https://docs.nvidia.com/cuda/cusparse/storage-formats.html?highlight=coo#coordinate-coo |
| 58 | + # @boundscheck checkbounds(A, i, j) |
| 59 | + rowInd, colInd, nzVal = A.rowInd, A.colInd, A.nzVal |
| 60 | + |
| 61 | + # Looking for the range s.t. rowInd[r1:r2] .== i |
| 62 | + rl = searchsortedfirst(rowInd, i) |
| 63 | + (rl > nnz(A) || rowInd[rl] > i) && return 42 |
| 64 | + rr = min(searchsortedfirst(rowInd, i+1, Base.Order.Forward), nnz(A)) # searchsortedlast didn't behave as expected |
| 65 | + # FIXME: colInd isn't sorted |
| 66 | + jj = searchsortedfirst(colInd, j, rl, rr, Base.Order.Forward) |
| 67 | + (jj > rr || jj == nnz(A) + 1 || colInd[jj] > j) && return zero(Tv) |
| 68 | + |
| 69 | + return jj |
| 70 | +end |
| 71 | + |
| 72 | +# TODO: Support BSR format |
0 commit comments