Skip to content

Commit 2dcf870

Browse files
authored
Merge pull request #18 from gridap/adding_set_index
Adding setindex!
2 parents 1cfa0c8 + 640e3c9 commit 2dcf870

5 files changed

Lines changed: 55 additions & 1 deletion

File tree

src/SparseMatricesCSR.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using SparseArrays
44
using LinearAlgebra
55
using SuiteSparse
66

7-
import Base: convert, copy, size, getindex, show, count, *, IndexStyle
7+
import Base: convert, copy, size, getindex, setindex!, show, count, *, IndexStyle
88
import LinearAlgebra: mul!, lu, lu!
99
import SparseArrays: nnz, getnzval, nonzeros, nzrange
1010
import SparseArrays: findnz, rowvals, getnzval, issparse

src/SparseMatrixCSR.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,26 @@ function getindex(A::SparseMatrixCSR{Bi,T}, i0::Integer, i1::Integer) where {Bi,
166166
k = searchsortedfirst(colvals(A), i1o, r1, r2, Base.Order.Forward)
167167
((k > r2) || (colvals(A)[k] != i1o)) ? zero(T) : nonzeros(A)[k]
168168
end
169+
function setindex!(A::SparseMatrixCSR{Bi,Tv,Ti}, _v, _i0::Integer, _i1::Integer) where {Bi,Tv,Ti}
170+
errmsg = "Trying to set an entry outside sparsity pattern"
171+
v = convert(Tv,_v)
172+
i0 = convert(Ti,_i0)
173+
i1 = convert(Ti,_i1)
174+
if !(1 <= i0 <= size(A, 1) && 1 <= i1 <= size(A, 2)); throw(BoundsError()); end
175+
o = getoffset(A)
176+
r1 = Int(getrowptr(A)[i0]+o)
177+
r2 = Int(getrowptr(A)[i0+1]-Bi)
178+
(r1 > r2) && throw(ArgumentError(errmsg))
179+
i1o = i1-o
180+
k = searchsortedfirst(colvals(A), i1o, r1, r2, Base.Order.Forward)
181+
if ((k > r2) || (colvals(A)[k] != i1o))
182+
throw(ArgumentError(errmsg))
183+
end
184+
A.nzval[k]=v
185+
end
186+
187+
188+
169189

170190
getrowptr(S::SparseMatrixCSR) = S.rowptr
171191
getnzval(S::SparseMatrixCSR) = S.nzval

src/SymSparseMatrixCSR.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ IndexStyle(::Type{<:SymSparseMatrixCSR}) = IndexCartesian()
5252
function getindex(A::SymSparseMatrixCSR, x::Integer, y::Integer)
5353
getindex(A.uppertrian,min(x,y),max(x,y))
5454
end
55+
function setindex!(A::SymSparseMatrixCSR, v, x::Integer, y::Integer)
56+
setindex!(A.uppertrian,v,min(x,y),max(x,y))
57+
end
5558

5659
getrowptr(S::SymSparseMatrixCSR) = getrowptr(S.uppertrian)
5760
getnzval(S::SymSparseMatrixCSR) = getnzval(S.uppertrian)

test/SparseMatrixCSR.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,27 @@ function test_csr(Bi,Tv,Ti)
2525
@test copy(CSR) == CSC
2626
end
2727
CSR = sparsecsr(Val(Bi),I,J,V)
28+
show(CSR)
2829
@test CSR == CSC
2930
@test copy(CSR) == CSC
3031
@test eltype(CSR) == Tv
3132
@test isa(CSR,SparseMatrixCSR{Bi,Tv,Ti})
3233

34+
for i=1:size(CSR,1)
35+
for j=1:size(CSR,2)
36+
if (i,j) in zip(I,J)
37+
CSR[i,j] = eltype(V)(i+j)
38+
@test CSR[i,j] eltype(V)(i+j)
39+
else
40+
try
41+
CSR[i,j] = eltype(V)(i+j)
42+
catch e
43+
@test isa(e,ArgumentError)
44+
end
45+
end
46+
end
47+
end
48+
3349
CSC = sparse(I,J,V,maxrows,maxcols)
3450
if Bi == 1
3551
CSR = sparsecsr(I,J,V,maxrows,maxcols)

test/SymSparseMatrixCSR.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ function test_csr(Bi,Tv,Ti)
3232
@test eltype(CSR) == Tv
3333
@test isa(CSR,SymSparseMatrixCSR{Bi,Tv,Ti})
3434

35+
for i=1:size(CSR,1)
36+
for j=1:size(CSR,2)
37+
if (i,j) in zip(I,J)
38+
CSR[i,j] = eltype(V)(i+j)
39+
@test CSR[i,j] eltype(V)(i+j)
40+
else
41+
try
42+
CSR[i,j] = eltype(V)(i+j)
43+
catch e
44+
@test isa(e,ArgumentError)
45+
end
46+
end
47+
end
48+
end
49+
3550
CSC = sparse(I,J,V,maxrows,maxcols)
3651
if Bi == 1
3752
CSR = symsparsecsr(I_up,J_up,V_up,maxrows,maxcols)

0 commit comments

Comments
 (0)