Skip to content

Commit d7491b2

Browse files
committed
Add initial tests for direct sparse matrix constructors
1 parent 12919bd commit d7491b2

1 file changed

Lines changed: 59 additions & 0 deletions

File tree

test/testsuite/sparse.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
elseif sparse_AT <: AbstractSparseMatrix
1010
matrix(sparse_AT, eltypes)
1111
matrix_construction(sparse_AT, eltypes)
12+
direct_vector_construction(sparse_AT, eltypes)
1213
broadcasting_matrix(sparse_AT, eltypes)
1314
mapreduce_matrix(sparse_AT, eltypes)
1415
linalg(sparse_AT, eltypes)
@@ -151,6 +152,64 @@ function matrix_construction(AT, eltypes)
151152
end
152153
end
153154

155+
# Helper function to derive direct matrix formats:
156+
# Create colptr, rowval, nzval for m x n matrix with 3 values per column
157+
function csc_vectors(m::Int, n::Int, ::Type{ET}; I::Type{<:Integer}=Int32) where {ET}
158+
# Fixed, deterministic 3 nnz per column; random nz values
159+
colptr = Vector{I}(undef, n + 1)
160+
rowval = Vector{I}()
161+
nzval = Vector{ET}()
162+
163+
colptr[1] = I(1)
164+
nnz_acc = 0
165+
for j in 1:n
166+
# Magic numbers
167+
rows_j = sort(unique(mod.(j .+ (1, 7, 13), m) .+ 1))
168+
append!(rowval, I.(rows_j))
169+
append!(nzval, rand(ET, length(rows_j)))
170+
nnz_acc += length(rows_j)
171+
colptr[j + 1] = I(nnz_acc + 1)
172+
end
173+
return colptr, rowval, nzval
174+
end
175+
function csr_vectors(m::Int, n::Int, ::Type{ET}; I::Type{<:Integer}=Int32) where {ET}
176+
# Build CSC for (n, m), then interpret as CSR for (m, n)
177+
colptr_nm, rowval_nm, nzval_nm = csc_vectors(n, m, ET; I=I)
178+
rowptr = colptr_nm
179+
colind = rowval_nm
180+
nzval = nzval_nm
181+
return rowptr, colind, nzval
182+
end
183+
# Construct appropriate sparse arrays
184+
function construct_sparse_matrix(AT::Type{<:GPUArrays.AbstractGPUSparseMatrixCSC}, ::Type{ET}, m::Int, n::Int; I::Type{<:Integer}=Int32) where {ET}
185+
colptr, rowval, nzval = csc_vectors(m, n, ET; I=I)
186+
dense_AT = GPUArrays.dense_array_type(AT)
187+
d_colptr = dense_AT(colptr)
188+
d_rowval = dense_AT(rowval)
189+
d_nzval = dense_AT(nzval)
190+
return GPUSparseMatrixCSC(d_colptr, d_rowval, d_nzval, (m, n))
191+
end
192+
function construct_sparse_matrix(AT::Type{<:GPUArrays.AbstractGPUSparseMatrixCSR}, ::Type{ET}, m::Int, n::Int; I::Type{<:Integer}=Int32) where {ET}
193+
rowptr, colind, nzval = csr_vectors(m, n, ET; I=I)
194+
dense_AT = GPUArrays.dense_array_type(AT)
195+
d_rowptr = dense_AT(rowptr)
196+
d_colind = dense_AT(colind)
197+
d_nzval = dense_AT(nzval)
198+
return GPUSparseMatrixCSR(d_rowptr, d_colind, d_nzval, (m, n))
199+
end
200+
function direct_vector_construction(AT::Type{<:GPUArrays.AbstractGPUSparseMatrix}, eltypes)
201+
for ET in eltypes
202+
m = 25
203+
n = 35
204+
x = construct_sparse_matrix(AT, ET, m, n)
205+
@test x isa AT{ET}
206+
@test size(x) == (m, n)
207+
end
208+
end
209+
function direct_vector_construction(AT, eltypes)
210+
# NOP
211+
end
212+
154213
function broadcasting_vector(AT, eltypes)
155214
dense_AT = GPUArrays.dense_array_type(AT)
156215
for ET in eltypes

0 commit comments

Comments
 (0)