Skip to content

Commit 60273e7

Browse files
authored
Backports to release-1.10 (#719)
- [x] #715
2 parents c3e31aa + 98be987 commit 60273e7

2 files changed

Lines changed: 22 additions & 9 deletions

File tree

src/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,7 @@ function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector,
934934
@inbounds for col in 1:n
935935
ycol = y[col]
936936
xcol = x[col]
937-
if _isnotzero(ycol) && _isnotzero(xcol)
937+
if _isnotzero(ycol) || _isnotzero(xcol)
938938
for k in rangefun(A, col)
939939
i = rvals[k]
940940
Aij = nzvals[k]

test/linalg.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -843,21 +843,34 @@ end
843843
end
844844

845845
@testset "generalized dot product" begin
846-
for i = 1:5
847-
A = sprand(ComplexF64, 10, 15, 0.4)
848-
Av = view(A, :, :)
849-
x = sprand(ComplexF64, 10, 0.5)
850-
y = sprand(ComplexF64, 15, 0.5)
846+
A = sprand(ComplexF64, 10, 15, 1.0)
847+
A15 = sprand(ComplexF64, 15, 15, 1.0)
848+
Av = view(A, :, :)
849+
vx = sprand(ComplexF64, 10, 0.5)
850+
vy = sprand(ComplexF64, 15, 0.5)
851+
vy2 = sprand(ComplexF64, 15, 0.5)
852+
for (x, y, y2) in ((vx, vy, vy2), (Vector(vx), Vector(vy), Vector(vy2)))
851853
@test dot(x, A, y) dot(Vector(x), A, Vector(y)) (Vector(x)' * Matrix(A)) * Vector(y)
852854
@test dot(x, A, y) dot(x, Av, y)
855+
@test dot(x, collect(A), y) dot(x, A, y)
856+
@test dot(y, collect(A)', x) dot(y, A', x)
857+
@test dot(y, transpose(collect(A)), x) dot(y, transpose(A), x)
858+
@test dot(y, Hermitian(collect(A15)), y2) dot(y, Hermitian(A15), y2)
859+
@test dot(y, Symmetric(collect(A15)), y2) dot(y, Symmetric(A15), y2)
860+
B = BitMatrix(rand(Bool, 10, 15))
861+
@test dot(x, A, y) dot(x, Matrix(A), y)
862+
@test_throws DimensionMismatch dot([x, x], A, y)
863+
@test_throws DimensionMismatch dot(x, A, [y, y])
864+
@test iszero(dot(spzeros(length(x)), A, y))
853865
end
854866

855867
for (T, trans) in ((Float64, Symmetric), (ComplexF64, Symmetric), (ComplexF64, Hermitian)), uplo in (:U, :L)
856868
B = sprandn(T, 10, 10, 0.2)
857869
x = sprandn(T, 10, 0.4)
858-
S = trans(B'B, uplo)
859-
Sd = trans(Matrix(B'B), uplo)
860-
@test dot(x, S, x) dot(x, Sd, x) dot(Vector(x), S, Vector(x)) dot(Vector(x), Sd, Vector(x))
870+
xd = Vector(x)
871+
S = trans(B, uplo)
872+
Sd = trans(Matrix(B), uplo)
873+
@test dot(x, S, x) dot(x, Sd, x) dot(xd, S, xd) dot(xd, Sd, xd)
861874
end
862875
end
863876

0 commit comments

Comments
 (0)