-
Notifications
You must be signed in to change notification settings - Fork 34
Expand file tree
/
Copy patharray.jl
More file actions
70 lines (61 loc) · 2.16 KB
/
array.jl
File metadata and controls
70 lines (61 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Public sparse matrix types of this module (defined below).
export oneSparseMatrixCSR, oneSparseMatrixCSC, oneSparseMatrixCOO
# Root of this module's sparse array hierarchy. It subtypes SparseArrays'
# `AbstractSparseArray{Tv, Ti, N}` (`Tv` = stored value type, `Ti` = index type,
# `N` = rank), so generic AbstractArray / SparseArrays methods dispatch to it.
abstract type oneAbstractSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end
# Rank-fixed aliases: rank 1 (vector) and rank 2 (matrix). Only matrix subtypes are
# defined in this section.
const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
const oneAbstractSparseMatrix{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 2}
# Sparse matrix in compressed sparse row (CSR) form, with its arrays held in
# `oneVector`s. `Tv` is the type of the stored values, `Ti` the index type.
# NOTE(review): declared `mutable`, presumably so a finalizer can release `handle`
# — confirm where the handle is created and destroyed.
mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
# opaque library matrix handle (presumably a oneMKL sparse handle — TODO confirm)
handle::matrix_handle_t
# row pointers; presumably length `dims[1] + 1` — confirm the index base (0 or 1)
rowPtr::oneVector{Ti}
# column index of each stored entry
colVal::oneVector{Ti}
# stored values, presumably aligned entry-for-entry with `colVal`
nzVal::oneVector{Tv}
# `(rows, columns)`; returned as-is by `size(A)`
dims::NTuple{2,Int}
# number of stored entries; returned as-is by `nnz(A)`
nnz::Ti
end
# Sparse matrix in compressed sparse column (CSC) form, with its arrays held in
# `oneVector`s. `Tv` is the type of the stored values, `Ti` the index type.
# NOTE(review): declared `mutable`, presumably so a finalizer can release `handle`
# — confirm where the handle is created and destroyed.
mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
# opaque library matrix handle (presumably a oneMKL sparse handle — TODO confirm)
handle::matrix_handle_t
# column pointers; presumably length `dims[2] + 1` — confirm the index base (0 or 1)
colPtr::oneVector{Ti}
# row index of each stored entry
rowVal::oneVector{Ti}
# stored values, presumably aligned entry-for-entry with `rowVal`
nzVal::oneVector{Tv}
# `(rows, columns)`; returned as-is by `size(A)`
dims::NTuple{2,Int}
# number of stored entries; returned as-is by `nnz(A)`
nnz::Ti
end
# Sparse matrix in coordinate (COO) form: one (row, column, value) triple per stored
# entry, held in `oneVector`s. `Tv` is the stored value type, `Ti` the index type.
# NOTE(review): declared `mutable`, presumably so a finalizer can release `handle`
# — confirm where the handle is created and destroyed.
mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
# opaque library matrix handle (presumably a oneMKL sparse handle — TODO confirm)
handle::matrix_handle_t
# row index of each stored entry — confirm the index base (0 or 1) and sort order
rowInd::oneVector{Ti}
# column index of each stored entry, presumably aligned with `rowInd`
colInd::oneVector{Ti}
# stored values, presumably aligned with `rowInd` / `colInd`
nzVal::oneVector{Tv}
# `(rows, columns)`; returned as-is by `size(A)`
dims::NTuple{2,Int}
# number of stored entries; returned as-is by `nnz(A)`
nnz::Ti
end
# Total element count of the logical (dense) shape: rows × columns, counting
# implicit zeros as well as stored entries.
function Base.length(A::oneAbstractSparseMatrix)
    nrows, ncols = A.dims
    return nrows * ncols
end

# Full shape, i.e. the stored `(rows, columns)` tuple.
function Base.size(A::oneAbstractSparseMatrix)
    return A.dims
end
# Extent of `A` along dimension `d`, following the `AbstractArray` convention:
#   * `d == 1` / `d == 2` → number of rows / columns, read from the stored `dims`;
#   * `d > 2`             → 1, because trailing dimensions are singletons (this
#                           matches `SparseMatrixCSC` and `Array`, so generic code
#                           calling e.g. `size(A, 3)` or `axes(A, 3)` keeps working);
#   * `d < 1`             → throws `ArgumentError`.
function Base.size(A::oneAbstractSparseMatrix, d::Integer)
    if 1 <= d <= 2
        return A.dims[d]
    elseif d > 2
        # trailing singleton dimension
        return 1
    else
        throw(ArgumentError("dimension must be ≥ 1, got $d"))
    end
end
# Number of stored entries, as recorded in the matrix's `nnz` field
# (returned with the matrix's index type `Ti`).
function SparseArrays.nnz(A::oneAbstractSparseMatrix)
    return A.nnz
end

# The stored values themselves: the `nzVal` vector is returned directly, not copied.
function SparseArrays.nonzeros(A::oneAbstractSparseMatrix)
    return A.nzVal
end
# Display support for the sparse matrix formats. Each format is rendered by first
# converting it to a host `SparseMatrixCSC` and reusing SparseArrays' printing. The
# 3-arg (REPL) method prints its own one-line header naming the original type.
let host = :SparseMatrixCSC
    for dev in (:oneSparseMatrixCSR, :oneSparseMatrixCSC, :oneSparseMatrixCOO)
        # 2-arg form on an IOContext: delegate entirely to the converted host matrix
        @eval Base.show(io::IOContext, x::$dev) = show(io, $host(x))

        # REPL form: summary header, then the entries unless a dimension is zero
        @eval function Base.show(io::IO, mime::MIME"text/plain", S::$dev)
            stored = nnz(S)
            nrows, ncols = size(S)
            noun = stored == 1 ? "entry" : "entries"
            print(io, nrows, "×", ncols, " ", typeof(S),
                  " with ", stored, " stored ", noun)
            (nrows == 0 || ncols == 0) && return
            println(io, ":")
            ctx = IOContext(io, :typeinfo => eltype(S))
            if ndims(S) == 1
                show(ctx, $host(S))
            else
                # so that we get the nice Braille pattern
                Base.print_array(ctx, $host(S))
            end
        end
    end
end