From 7a78eb04abebc644ae90c676514d9a65c5ea57f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Ara=C3=BAjo?= Date: Sat, 16 May 2026 15:27:51 +0200 Subject: [PATCH] fix 3-argument dot (#715) --- src/linalg.jl | 2 +- test/linalg.jl | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 0749e07e..fcdf1f20 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -1387,7 +1387,7 @@ function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector, @inbounds for col in axes(A,2) ycol = y[col] xcol = x[col] - if _isnotzero(ycol) && _isnotzero(xcol) + if _isnotzero(ycol) || _isnotzero(xcol) for k in rangefun(A, col) i = rvals[k] Aij = nzvals[k] diff --git a/test/linalg.jl b/test/linalg.jl index d6e21e8a..4f5fb50f 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -967,18 +967,20 @@ end end @testset "generalized dot product" begin - for i = 1:5 - A = sprand(ComplexF64, 10, 15, 0.4) - Av = view(A, :, :) - x = sprand(ComplexF64, 10, 0.5) - y = sprand(ComplexF64, 15, 0.5) + A = sprand(ComplexF64, 10, 15, 1.0) + A15 = sprand(ComplexF64, 15, 15, 1.0) + Av = view(A, :, :) + vx = sprand(ComplexF64, 10, 0.5) + vy = sprand(ComplexF64, 15, 0.5) + vy2 = sprand(ComplexF64, 15, 0.5) + for (x, y, y2) in ((vx, vy, vy2), (Vector(vx), Vector(vy), Vector(vy2))) @test dot(x, A, y) ≈ dot(Vector(x), A, Vector(y)) ≈ (Vector(x)' * Matrix(A)) * Vector(y) @test dot(x, A, y) ≈ dot(x, Av, y) @test dot(x, collect(A), y) ≈ dot(x, A, y) @test dot(y, collect(A)', x) ≈ dot(y, A', x) @test dot(y, transpose(collect(A)), x) ≈ dot(y, transpose(A), x) - @test dot(y, Hermitian(collect(A)' * collect(A)), y) ≈ dot(y, Hermitian(A' * A), y) - @test dot(y, Symmetric(collect(A)' * collect(A)), y) ≈ dot(y, Symmetric(A' * A), y) + @test dot(y, Hermitian(collect(A15)), y2) ≈ dot(y, Hermitian(A15), y2) + @test dot(y, Symmetric(collect(A15)), y2) ≈ dot(y, Symmetric(A15), y2) B = BitMatrix(rand(Bool, 10, 15)) @test dot(x, A, y) ≈ dot(x, Matrix(A), y) @test_throws DimensionMismatch dot([x, x], A, y) @@ -990,8 +992,8 @@ end B = sprandn(T, 10, 10, 0.2) x = sprandn(T, 10, 0.4) xd = Vector(x) - S = trans(B'B, uplo) - Sd = trans(Matrix(B'B), uplo) + S = trans(B, uplo) + Sd = trans(Matrix(B), uplo) @test dot(x, S, x) ≈ dot(x, Sd, x) ≈ dot(xd, S, xd) ≈ dot(xd, Sd, xd) end end