Skip to content

Commit 5b81329

Browse files
alonsoC1skshyatt
authored andcommitted
Fixes for COO indices. It works now
1 parent 7c2ec65 commit 5b81329

1 file changed

Lines changed: 25 additions & 11 deletions

File tree

src/device/indexing.jl

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ Base.IndexStyle(::Type{GPUSparseDeviceVector}) = Base.IndexLinear()
77
# Scalar indexing
88
## Adapted from SparseArrays.AbstractSparseVector
99

10-
@propagate_inbounds function Base.getindex(v::GPUSparseDeviceVector{Tv,Ti}, i::Integer) where {Tv,Ti}
10+
@propagate_inbounds function Base.getindex(
11+
v::GPUSparseDeviceVector{Tv,Ti},
12+
i::Integer,
13+
) where {Tv,Ti}
1114
@boundscheck checkbounds(v, i)
1215
m = nnz(v)
1316
nzind = nonzeroinds(v)
@@ -20,10 +23,17 @@ end
2023
# TODO: Logical indexing
2124

2225
# Indexing by colon not implemented. Non-scalar indexing would allocate in device code
23-
@propagate_inbounds Base.getindex(A::AbstractGPUSparseDeviceMatrix, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2])
26+
@propagate_inbounds Base.getindex(
27+
A::AbstractGPUSparseDeviceMatrix,
28+
I::Tuple{Integer,Integer},
29+
) = getindex(A, I[1], I[2])
2430

2531
## Adapted logic from SparseArrays.AbstractSparseMatrixCSC
26-
@propagate_inbounds function Base.getindex(A::GPUSparseDeviceMatrixCSC{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti}
32+
@propagate_inbounds function Base.getindex(
33+
A::GPUSparseDeviceMatrixCSC{Tv,Ti},
34+
i::Integer,
35+
j::Integer,
36+
) where {Tv,Ti}
2737
@boundscheck checkbounds(A, i, j)
2838
colPtr, rowVal, nzVal = getcolptr(A), rowvals(A), nonzeros(A)
2939

@@ -32,12 +42,15 @@ end
3242
rr = convert(Ti, @inbounds colPtr[j+1] - 1)
3343
(rl > rr) && return zero(Tv)
3444

35-
# possible_row = @view rowVal[rl:rr]
3645
ii = searchsortedfirst(rowVal, convert(Ti, i), rl, rr, Base.Order.Forward)
3746
(ii <= nnz(A) && rowVal[ii] == i) ? nzVal[ii] : zero(Tv)
3847
end
3948

40-
@propagate_inbounds function Base.getindex(A::GPUSparseDeviceMatrixCSR{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti}
49+
@propagate_inbounds function Base.getindex(
50+
A::GPUSparseDeviceMatrixCSR{Tv,Ti},
51+
i::Integer,
52+
j::Integer,
53+
) where {Tv,Ti}
4154
@boundscheck checkbounds(A, i, j)
4255
rowPtr, colVal, nzVal = A.rowPtr, A.colVal, A.nzVal
4356

@@ -46,27 +59,28 @@ end
4659
rb = convert(Ti, @inbounds rowPtr[i+1] - 1)
4760
(rt > rb) && return zero(Tv)
4861

49-
# possible_col = @view colVal[rt:rb]
5062
jj = searchsortedfirst(colVal, convert(Ti, j), rt, rb, Base.Order.Forward)
5163
(jj <= nnz(A) && colVal[jj] == j) ? nzVal[jj] : zero(Tv)
5264
end
5365

5466
## Adapted from CUDA.jl/blob/lib/cusparse/src/array.jl#L490
55-
# FIXME: Currently not correct
56-
@propagate_inbounds function Base.getindex(A::GPUSparseDeviceMatrixCOO{Tv,Ti}, i::Integer, j::Integer) where {Tv,Ti}
67+
@propagate_inbounds function Base.getindex(
68+
A::GPUSparseDeviceMatrixCOO{Tv,Ti},
69+
i::Integer,
70+
j::Integer,
71+
) where {Tv,Ti}
5772
# COO in CUDA is assumed to be sorted by row: https://docs.nvidia.com/cuda/cusparse/storage-formats.html?highlight=coo#coordinate-coo
58-
# @boundscheck checkbounds(A, i, j)
73+
@boundscheck checkbounds(A, i, j)
5974
rowInd, colInd, nzVal = A.rowInd, A.colInd, A.nzVal
6075

6176
# Looking for the range s.t. rowInd[r1:r2] .== i
6277
rl = searchsortedfirst(rowInd, i)
6378
(rl > nnz(A) || rowInd[rl] > i) && return 42
6479
rr = min(searchsortedfirst(rowInd, i+1, Base.Order.Forward), nnz(A)) # searchsortedlast didn't behave as expected
65-
# FIXME: colInd isn't sorted
6680
jj = searchsortedfirst(colInd, j, rl, rr, Base.Order.Forward)
6781
(jj > rr || jj == nnz(A) + 1 || colInd[jj] > j) && return zero(Tv)
6882

69-
return jj
83+
return nzVal[jj]
7084
end
7185

7286
# TODO: Support BSR format

0 commit comments

Comments
 (0)