We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2c34bc4 commit 0a1ae38Copy full SHA for 0a1ae38
2 files changed
src/implementations/svd.jl
@@ -229,13 +229,13 @@ end
229
function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ}
230
ϵ = similar(USVᴴ[2], compute_error)
231
(U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg)
232
- return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(eltype(ϵ)))
+ return compute_error ? (U, S, Vᴴ, collect(ϵ)[1]) : (U, S, Vᴴ, -one(eltype(ϵ)))
233
end
234
function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool = true)
235
Tr = real(eltype(A))
236
- ϵ = compute_error ? zeros(Tr, 1) : zeros(Tr, 0)
+ ϵ = zeros(Tr, compute_error)
237
U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg)
238
- return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(Tr))
+ return compute_error ? (U, S, Vᴴ, collect(ϵ)[1]) : (U, S, Vᴴ, -one(Tr))
239
240
241
# Diagonal logic
test/cuda/svd.jl
@@ -140,7 +140,7 @@ end
140
S₀ = svd_vals(hA)
141
r = k
142
143
- U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r))
+ U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r), compute_error=false)
144
@test length(S1.diag) == r
145
@test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]
146
@test norm(A - U1 * S1 * V1ᴴ) ≈ ϵ1
@@ -149,7 +149,7 @@ end
149
s = 1 + sqrt(eps(real(T)))
150
trunc2 = trunctol(; atol = s * S₀[r + 1])
151
152
- U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1]))
+ U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1]), compute_error=false)
153
@test length(S2.diag) == r
154
@test U1 ≈ U2
155
@test parent(S1) ≈ parent(S2)
0 commit comments