diff --git a/src/linalg.jl b/src/linalg.jl index 7436c68c..8df6b2b6 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -934,7 +934,7 @@ function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector, @inbounds for col in 1:n ycol = y[col] xcol = x[col] - if _isnotzero(ycol) && _isnotzero(xcol) + if _isnotzero(ycol) || _isnotzero(xcol) for k in rangefun(A, col) i = rvals[k] Aij = nzvals[k] diff --git a/test/linalg.jl b/test/linalg.jl index e7533255..af7b2ff5 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -843,21 +843,34 @@ end end @testset "generalized dot product" begin - for i = 1:5 - A = sprand(ComplexF64, 10, 15, 0.4) - Av = view(A, :, :) - x = sprand(ComplexF64, 10, 0.5) - y = sprand(ComplexF64, 15, 0.5) + A = sprand(ComplexF64, 10, 15, 1.0) + A15 = sprand(ComplexF64, 15, 15, 1.0) + Av = view(A, :, :) + vx = sprand(ComplexF64, 10, 0.5) + vy = sprand(ComplexF64, 15, 0.5) + vy2 = sprand(ComplexF64, 15, 0.5) + for (x, y, y2) in ((vx, vy, vy2), (Vector(vx), Vector(vy), Vector(vy2))) @test dot(x, A, y) ≈ dot(Vector(x), A, Vector(y)) ≈ (Vector(x)' * Matrix(A)) * Vector(y) @test dot(x, A, y) ≈ dot(x, Av, y) + @test dot(x, collect(A), y) ≈ dot(x, A, y) + @test dot(y, collect(A)', x) ≈ dot(y, A', x) + @test dot(y, transpose(collect(A)), x) ≈ dot(y, transpose(A), x) + @test dot(y, Hermitian(collect(A15)), y2) ≈ dot(y, Hermitian(A15), y2) + @test dot(y, Symmetric(collect(A15)), y2) ≈ dot(y, Symmetric(A15), y2) + B = BitMatrix(rand(Bool, 10, 15)) + @test dot(x, A, y) ≈ dot(x, Matrix(A), y) + @test_throws DimensionMismatch dot([x, x], A, y) + @test_throws DimensionMismatch dot(x, A, [y, y]) + @test iszero(dot(spzeros(length(x)), A, y)) end for (T, trans) in ((Float64, Symmetric), (ComplexF64, Symmetric), (ComplexF64, Hermitian)), uplo in (:U, :L) B = sprandn(T, 10, 10, 0.2) x = sprandn(T, 10, 0.4) - S = trans(B'B, uplo) - Sd = trans(Matrix(B'B), uplo) - @test dot(x, S, x) ≈ dot(x, Sd, x) ≈ dot(Vector(x), S, Vector(x)) ≈ dot(Vector(x), Sd, Vector(x)) + xd = Vector(x) + S = trans(B, uplo) + Sd = trans(Matrix(B), uplo) + @test dot(x, S, x) ≈ dot(x, Sd, x) ≈ dot(xd, S, xd) ≈ dot(xd, Sd, xd) end end