diff --git a/src/SparseMatricesCSR.jl b/src/SparseMatricesCSR.jl index 4ea4ef3..cdd3d9c 100644 --- a/src/SparseMatricesCSR.jl +++ b/src/SparseMatricesCSR.jl @@ -13,6 +13,7 @@ export SparseMatrixCSR export SymSparseMatrixCSR export sparsecsr, symsparsecsr export colvals, getBi, getoffset +export spzeroscsr import Polyester: @batch import Atomix: @atomic @@ -33,7 +34,7 @@ end # module # DONE getrowptr # DONE getnzval # DONE getcolval -# TODO spzeros +# DONE spzeros # TODO spdiagm # TODO blockdiag # TODO sprand diff --git a/src/SparseMatrixCSR.jl b/src/SparseMatrixCSR.jl index 838edbf..e636a16 100644 --- a/src/SparseMatrixCSR.jl +++ b/src/SparseMatrixCSR.jl @@ -103,6 +103,24 @@ sparsecsr(::Val{Bi},I,J,V,m,n) where Bi = SparseMatrixCSR{Bi}(transpose(sparse(J sparsecsr(::Val{Bi},I,J,V,m,n,combine) where Bi = SparseMatrixCSR{Bi}(transpose(sparse(J,I,V,n,m,combine))) dimlub(I) = isempty(I) ? 0 : Int(maximum(I)) +""" + spzeroscsr(args...) + +Crate a `SparseMatricesCSR` of all zeros +""" +spzeroscsr(m::Integer,n::Integer) = SparseMatrixCSR(transpose(spzeros(n,m))) +spzeroscsr(::Type{Tv},m::Integer,n::Integer) where Tv = SparseMatrixCSR(transpose(spzeros(Tv,n,m))) +spzeroscsr(::Type{Tv},::Type{Ti},m::Integer,n::Integer) where {Tv,Ti} = SparseMatrixCSR(transpose(spzeros(Tv,Ti,n,m))) +# de-splatting variants +spzeroscsr(sz::Tuple{Integer,Integer}) = SparseMatrixCSR(transpose(spzeros((sz[2],sz[1])))) +spzeroscsr(::Type{Tv},sz::Tuple{Integer,Integer}) where Tv = SparseMatrixCSR(transpose(spzeros((sz[2],sz[1])))) +spzeroscsr(::Type{Tv},::Type{Ti},sz::Tuple{Integer,Integer}) where {Ti,Tv} = SparseMatrixCSR(transpose(spzeros(Tv,Ti,(sz[2],sz[1])))) +# below methods require julia 1.10 or later +spzeroscsr(I::AbstractVector,J::AbstractVector) = SparseMatrixCSR(transpose(spzeros(J,I))) +spzeroscsr(I::AbstractVector,J::AbstractVector,m,n) = SparseMatrixCSR(transpose(spzeros(J,I,n,m))) +spzeroscsr(::Type{Tv},I::AbstractVector,J::AbstractVector,m,n) where Tv = SparseMatrixCSR(transpose(spzeros(Tv,J,I,n,m))) + + Base.convert(::Type{T},a::T) where T<:SparseMatrixCSR = a function Base.convert( ::Type{SparseMatrixCSR{Bi,Tv,Ti}},a::SparseMatrixCSR{Bi}) where {Bi,Tv,Ti} diff --git a/test/SparseMatrixCSR.jl b/test/SparseMatrixCSR.jl index 35b7403..b113102 100644 --- a/test/SparseMatrixCSR.jl +++ b/test/SparseMatrixCSR.jl @@ -139,6 +139,43 @@ function test_csr(Bi,Tv,Ti) @test nnz(A) == length(SparseArrays.nzvalview(A)) == 3 @test SparseArrays.nzvalview(A) == [4., 5, 6] end + + # spzeroscsr tests + csc = spzeros(10, 9) + csr = spzeroscsr(10, 9) + @test csc == csr + + csc = spzeros(Tv, 10, 9) + csr = spzeroscsr(Tv, 10, 9) + @test csc == csr + + csc = spzeros(Tv, Ti, 10, 9) + csr = spzeroscsr(Tv, Ti, 10, 9) + @test csc == csr + + csc = spzeros((10, 9)) + csr = spzeroscsr((10, 9)) + @test csc == csr + + csc = spzeros(Tv, (10, 9)) + csr = spzeroscsr(Tv, (10, 9)) + @test csc == csr + + csc = spzeros(Tv, Ti, (10, 9)) + csr = spzeroscsr(Tv, Ti, (10, 9)) + @test csc == csr + + csc = spzeros(I, J) + csr = spzeroscsr(I, J) + @test csc == csr + + csc = spzeros(I, J, maxrows, maxcols) + csr = spzeroscsr(I, J, maxrows, maxcols) + @test csc == csr + + csc = spzeros(Tv, I, J, maxrows, maxcols) + csr = spzeroscsr(Tv, I, J, maxrows, maxcols) + @test csc == csr end function test_lu(Bi,I,J,V)