Skip to content

Commit 6adadc8

Browse files
authored
fix 3-argument dot (#715)
1 parent ce5909e commit 6adadc8

2 files changed

Lines changed: 12 additions & 10 deletions

File tree

src/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1387,7 +1387,7 @@ function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector,
13871387
@inbounds for col in axes(A,2)
13881388
ycol = y[col]
13891389
xcol = x[col]
1390-
if _isnotzero(ycol) && _isnotzero(xcol)
1390+
if _isnotzero(ycol) || _isnotzero(xcol)
13911391
for k in rangefun(A, col)
13921392
i = rvals[k]
13931393
Aij = nzvals[k]

test/linalg.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,18 +1001,20 @@ end
10011001
end
10021002

10031003
@testset "generalized dot product" begin
1004-
for i = 1:5
1005-
A = sprand(ComplexF64, 10, 15, 0.4)
1006-
Av = view(A, :, :)
1007-
x = sprand(ComplexF64, 10, 0.5)
1008-
y = sprand(ComplexF64, 15, 0.5)
1004+
A = sprand(ComplexF64, 10, 15, 1.0)
1005+
A15 = sprand(ComplexF64, 15, 15, 1.0)
1006+
Av = view(A, :, :)
1007+
vx = sprand(ComplexF64, 10, 0.5)
1008+
vy = sprand(ComplexF64, 15, 0.5)
1009+
vy2 = sprand(ComplexF64, 15, 0.5)
1010+
for (x, y, y2) in ((vx, vy, vy2), (Vector(vx), Vector(vy), Vector(vy2)))
10091011
@test dot(x, A, y) dot(Vector(x), A, Vector(y)) (Vector(x)' * Matrix(A)) * Vector(y)
10101012
@test dot(x, A, y) dot(x, Av, y)
10111013
@test dot(x, collect(A), y) dot(x, A, y)
10121014
@test dot(y, collect(A)', x) dot(y, A', x)
10131015
@test dot(y, transpose(collect(A)), x) dot(y, transpose(A), x)
1014-
@test dot(y, Hermitian(collect(A)' * collect(A)), y) dot(y, Hermitian(A' * A), y)
1015-
@test dot(y, Symmetric(collect(A)' * collect(A)), y) dot(y, Symmetric(A' * A), y)
1016+
@test dot(y, Hermitian(collect(A15)), y2) dot(y, Hermitian(A15), y2)
1017+
@test dot(y, Symmetric(collect(A15)), y2) dot(y, Symmetric(A15), y2)
10161018
B = BitMatrix(rand(Bool, 10, 15))
10171019
@test dot(x, A, y) dot(x, Matrix(A), y)
10181020
@test_throws DimensionMismatch dot([x, x], A, y)
@@ -1024,8 +1026,8 @@ end
10241026
B = sprandn(T, 10, 10, 0.2)
10251027
x = sprandn(T, 10, 0.4)
10261028
xd = Vector(x)
1027-
S = trans(B'B, uplo)
1028-
Sd = trans(Matrix(B'B), uplo)
1029+
S = trans(B, uplo)
1030+
Sd = trans(Matrix(B), uplo)
10291031
@test dot(x, S, x) dot(x, Sd, x) dot(xd, S, xd) dot(xd, Sd, xd)
10301032
end
10311033
end

0 commit comments

Comments
 (0)