@@ -13,7 +13,7 @@ using GPUArrays
1313using Adapt
1414using SparseArrays, LinearAlgebra
1515
16- import GPUArrays: dense_array_type
16+ import GPUArrays: dense_array_type, GPUSparseMatrixCSC, GPUSparseMatrixCSR
1717
1818import KernelAbstractions
1919import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
@@ -150,6 +150,9 @@ mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC
150150 new {Tv, Ti} (colPtr, rowVal, nzVal, dims, length (nzVal))
151151 end
152152end
153+ function GPUSparseMatrixCSC (colPtr:: JLArray{Ti, 1} , rowVal:: JLArray{Ti, 1} , nzVal:: JLArray{Tv, 1} , dims:: NTuple{2,<:Integer} ) where {Tv, Ti <: Integer }
154+ return JLSparseMatrixCSC (colPtr, rowVal, nzVal, dims)
155+ end
153156function JLSparseMatrixCSC (colPtr:: JLArray{Ti, 1} , rowVal:: JLArray{Ti, 1} , nzVal:: JLArray{Tv, 1} , dims:: NTuple{2,<:Integer} ) where {Tv, Ti <: Integer }
154157 return JLSparseMatrixCSC {Tv, Ti} (colPtr, rowVal, nzVal, dims)
155158end
181184function JLSparseMatrixCSR (rowPtr:: JLArray{Ti, 1} , colVal:: JLArray{Ti, 1} , nzVal:: JLArray{Tv, 1} , dims:: NTuple{2,<:Integer} ) where {Tv, Ti <: Integer }
182185 return JLSparseMatrixCSR {Tv, Ti} (rowPtr, colVal, nzVal, dims)
183186end
187+ function GPUSparseMatrixCSR (rowPtr:: JLArray{Ti, 1} , colVal:: JLArray{Ti, 1} , nzVal:: JLArray{Tv, 1} , dims:: NTuple{2,<:Integer} ) where {Tv, Ti <: Integer }
188+ return JLSparseMatrixCSR (rowPtr, colVal, nzVal, dims)
189+ end
184190function SparseArrays. SparseMatrixCSC (x:: JLSparseMatrixCSR )
185191 x_transpose = SparseMatrixCSC (size (x, 2 ), size (x, 1 ), Array (x. rowPtr), Array (x. colVal), Array (x. nzVal))
186192 return SparseMatrixCSC (transpose (x_transpose))
0 commit comments