Skip to content

Commit 12919bd

Browse files
committed
Add JLArrays sparse constructors for CSC and CSR
1 parent 817326f commit 12919bd

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

lib/JLArrays/src/JLArrays.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using GPUArrays
1313
using Adapt
1414
using SparseArrays, LinearAlgebra
1515

16-
import GPUArrays: dense_array_type
16+
import GPUArrays: dense_array_type, GPUSparseMatrixCSC, GPUSparseMatrixCSR
1717

1818
import KernelAbstractions
1919
import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
@@ -150,6 +150,9 @@ mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC
150150
new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
151151
end
152152
end
153+
function GPUSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
154+
return JLSparseMatrixCSC(colPtr, rowVal, nzVal, dims)
155+
end
153156
function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
154157
return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims)
155158
end
@@ -181,6 +184,9 @@ end
181184
function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
182185
return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
183186
end
187+
function GPUSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
188+
return JLSparseMatrixCSR(rowPtr, colVal, nzVal, dims)
189+
end
184190
function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
185191
x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal))
186192
return SparseMatrixCSC(transpose(x_transpose))

0 commit comments

Comments
 (0)