Skip to content

Commit 7d649de

Browse files
committed
Switch to svd_trunc and svd_trunc_no_error
1 parent 1fb8b6a commit 7d649de

12 files changed

Lines changed: 86 additions & 85 deletions

File tree

docs/src/user_interface/truncations.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,16 @@ combined_trunc = truncrank(10) & trunctol(; atol = 1e-6);
113113

114114
## Truncation Error
115115

116-
When using truncated decompositions such as [`svd_trunc_with_err`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned.
116+
When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned.
117117
This error is defined as the 2-norm of the discarded singular values or eigenvalues, providing a measure of the approximation quality.
118-
For `svd_trunc_with_err` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix.
118+
For `svd_trunc` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix.
119119
For the case of `eig_trunc`, this interpretation does not hold because the norm of the non-unitary matrix of eigenvectors and its inverse also influence the approximation quality.
120120

121121

122122
For example:
123123
```jldoctest truncations; output=false
124124
using LinearAlgebra: norm
125-
U, S, Vᴴ, ϵ = svd_trunc_with_err(A; trunc=truncrank(2))
125+
U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2))
126126
norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values
127127
128128
# output

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -170,15 +170,15 @@ for svd_f in (:svd_compact, :svd_full)
170170
end
171171
end
172172

173-
function ChainRulesCore.rrule(::typeof(svd_trunc_with_err!), A, USVᴴ, alg::TruncatedAlgorithm)
173+
function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm)
174174
Ac = copy_input(svd_compact, A)
175175
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
176176
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
177177
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
178-
return (USVᴴ′..., ϵ), _make_svd_trunc_with_err_pullback(A, USVᴴ, ind)
178+
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
179179
end
180-
function _make_svd_trunc_with_err_pullback(A, USVᴴ, ind)
181-
function svd_trunc_with_err_pullback(ΔUSVᴴϵ)
180+
function _make_svd_trunc_pullback(A, USVᴴ, ind)
181+
function svd_trunc_pullback(ΔUSVᴴϵ)
182182
ΔA = zero(A)
183183
ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
184184
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
@@ -187,26 +187,26 @@ function _make_svd_trunc_with_err_pullback(A, USVᴴ, ind)
187187
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
188188
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
189189
end
190-
function svd_trunc_with_err_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
190+
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
191191
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
192192
end
193-
return svd_trunc_with_err_pullback
193+
return svd_trunc_pullback
194194
end
195195

196-
function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm)
196+
function ChainRulesCore.rrule(::typeof(svd_trunc_no_error!), A, USVᴴ, alg::TruncatedAlgorithm)
197197
Ac = copy_input(svd_compact, A)
198198
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
199199
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
200-
return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind)
200+
return USVᴴ′, _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
201201
end
202-
function _make_svd_trunc_pullback(A, USVᴴ, ind)
202+
function _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
203203
function svd_trunc_pullback(ΔUSVᴴ)
204204
ΔA = zero(A)
205205
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
206206
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
207207
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
208208
end
209-
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
209+
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
210210
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
211211
end
212212
return svd_trunc_pullback

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,14 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
303303
return S_codual, svd_vals_adjoint
304304
end
305305

306-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_with_err), Any, MatrixAlgebraKit.AbstractAlgorithm}
307-
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_with_err)}, A_dA::CoDual, alg_dalg::CoDual)
306+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
307+
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
308308
# compute primal
309309
A_ = Mooncake.primal(A_dA)
310310
dA_ = Mooncake.tangent(A_dA)
311311
A, dA = arrayify(A_, dA_)
312312
alg = Mooncake.primal(alg_dalg)
313-
output = svd_trunc_with_err(A, alg)
313+
output = svd_trunc(A, alg)
314314
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
315315
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
316316
# pass). For many types this is done automatically when the forward step returns, but
@@ -319,7 +319,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_with_err)}, A_dA::CoDual, al
319319
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
320320
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
321321
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
322-
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc_with_err does not yet support non-zero tangent for the truncation error"
322+
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error"
323323
U, dU = arrayify(Utrunc, dUtrunc_)
324324
S, dS = arrayify(Strunc, dStrunc_)
325325
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
@@ -332,14 +332,14 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_with_err)}, A_dA::CoDual, al
332332
return output_codual, svd_trunc_adjoint
333333
end
334334

335-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
336-
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
335+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
336+
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
337337
# compute primal
338338
A_ = Mooncake.primal(A_dA)
339339
dA_ = Mooncake.tangent(A_dA)
340340
A, dA = arrayify(A_, dA_)
341341
alg = Mooncake.primal(alg_dalg)
342-
output = svd_trunc(A, alg)
342+
output = svd_trunc_no_error(A, alg)
343343
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
344344
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
345345
# pass). For many types this is done automatically when the forward step returns, but

src/MatrixAlgebraKit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ export project_hermitian, project_antihermitian, project_isometric
1616
export project_hermitian!, project_antihermitian!, project_isometric!
1717
export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
1818
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
19-
export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_with_err
20-
export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_with_err!
19+
export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_no_error
20+
export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_no_error!
2121
export eigh_full, eigh_vals, eigh_trunc
2222
export eigh_full!, eigh_vals!, eigh_trunc!
2323
export eig_full, eig_vals, eig_trunc

src/implementations/svd.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltype(A))), A)
44
copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A)
55
copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A)
6-
copy_input(::Union{typeof(svd_trunc), typeof(svd_trunc_with_err)}, A) = copy_input(svd_compact, A)
6+
copy_input(::Union{typeof(svd_trunc), typeof(svd_trunc_no_error)}, A) = copy_input(svd_compact, A)
77

88
copy_input(::typeof(svd_full), A::Diagonal) = copy(A)
99

@@ -89,7 +89,7 @@ end
8989
function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlgorithm)
9090
return similar(A, real(eltype(A)), (min(size(A)...),))
9191
end
92-
function initialize_output(::Union{typeof(svd_trunc!), typeof(svd_trunc_with_err!)}, A, alg::TruncatedAlgorithm)
92+
function initialize_output(::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A, alg::TruncatedAlgorithm)
9393
return initialize_output(svd_compact!, A, alg.alg)
9494
end
9595

@@ -206,13 +206,13 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
206206
return S
207207
end
208208

209-
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
209+
function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm)
210210
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
211211
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
212212
return USVᴴtrunc
213213
end
214214

215-
function svd_trunc_with_err!(A, USVᴴ, alg::TruncatedAlgorithm)
215+
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
216216
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
217217
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
218218
ϵ = truncation_error!(diagview(S), ind)
@@ -269,7 +269,7 @@ end
269269
###
270270

271271
function check_input(
272-
::Union{typeof(svd_trunc!), typeof(svd_trunc_with_err!)}, A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized
272+
::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized
273273
)
274274
m, n = size(A)
275275
minmn = min(m, n)
@@ -285,7 +285,7 @@ function check_input(
285285
end
286286

287287
function initialize_output(
288-
::Union{typeof(svd_trunc!), typeof(svd_trunc_with_err!)}, A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}
288+
::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}
289289
)
290290
m, n = size(A)
291291
minmn = min(m, n)
@@ -369,9 +369,9 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
369369
return USVᴴ
370370
end
371371

372-
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
372+
function svd_trunc_no_error!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
373373
U, S, Vᴴ = USVᴴ
374-
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
374+
check_input(svd_trunc_no_error!, A, (U, S, Vᴴ), alg.alg)
375375
_gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...)
376376

377377
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
@@ -383,9 +383,9 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
383383
return Utr, Str, Vᴴtr
384384
end
385385

386-
function svd_trunc_with_err!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
386+
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
387387
U, S, Vᴴ = USVᴴ
388-
check_input(svd_trunc_with_err!, A, (U, S, Vᴴ), alg.alg)
388+
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
389389
_gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...)
390390

391391
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong

src/interface/svd.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and
4242
@functiondef svd_compact
4343

4444
"""
45-
svd_trunc_with_err(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
46-
svd_trunc_with_err(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
47-
svd_trunc_with_err!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
48-
svd_trunc_with_err!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
45+
svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
46+
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
47+
svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
48+
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
4949
5050
Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
5151
`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
@@ -86,22 +86,23 @@ truncation strategy is already embedded in the algorithm.
8686
possibly destroys the input matrix `A`. Always use the return value of the function
8787
as it may not always be possible to use the provided `USVᴴ` as output.
8888
89-
See also [`svd_trunc(!)`](@ref svd_trunc), [`svd_full(!)`](@ref svd_full),
89+
See also [`svd_trunc_no_error(!)`](@ref svd_trunc), [`svd_full(!)`](@ref svd_full),
9090
[`svd_compact(!)`](@ref svd_compact), [`svd_vals(!)`](@ref svd_vals),
9191
and [Truncations](@ref) for more information on truncation strategies.
9292
"""
93-
@functiondef svd_trunc_with_err
93+
@functiondef svd_trunc
9494

9595
"""
96-
svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ
97-
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ
98-
svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ
99-
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ
96+
svd_trunc_no_error(A; [trunc], kwargs...) -> U, S, Vᴴ
97+
svd_trunc_no_error(A, alg::AbstractAlgorithm) -> U, S, Vᴴ
98+
svd_trunc_no_error!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ
99+
svd_trunc_no_error!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ
100100
101101
Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
102102
`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
103103
`(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a
104104
square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy.
105+
The truncation error is *not* returned.
105106
106107
## Truncation
107108
The truncation strategy can be controlled via the `trunc` keyword argument. This can be
@@ -130,15 +131,15 @@ When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be spec
130131
truncation strategy is already embedded in the algorithm.
131132
132133
!!! note
133-
The bang method `svd_trunc!` optionally accepts the output structure and
134+
The bang method `svd_trunc_no_error!` optionally accepts the output structure and
134135
possibly destroys the input matrix `A`. Always use the return value of the function
135136
as it may not always be possible to use the provided `USVᴴ` as output.
136137
137138
See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact),
138-
[`svd_vals(!)`](@ref svd_vals), and [Truncations](@ref) for more information on
139-
truncation strategies.
139+
[`svd_vals(!)`](@ref svd_vals), [`svd_trunc(!)`](@ref svd_trunc) and
140+
[Truncations](@ref) for more information on truncation strategies.
140141
"""
141-
@functiondef svd_trunc
142+
@functiondef svd_trunc_no_error
142143

143144
"""
144145
svd_vals(A; kwargs...) -> S
@@ -173,7 +174,7 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!)
173174
end
174175
end
175176

176-
for f in (:svd_trunc!, :svd_trunc_with_err!)
177+
for f in (:svd_trunc!, :svd_trunc_no_error!)
177178
@eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...)
178179
if alg isa TruncatedAlgorithm
179180
isnothing(trunc) ||

test/amd/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,14 @@ end
140140
# minmn = min(m, n)
141141
# r = minmn - 2
142142
#
143-
# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc=truncrank(r))
143+
# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r))
144144
# @test length(S1.diag) == r
145145
# @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]
146146
#
147147
# s = 1 + sqrt(eps(real(T)))
148148
# trunc2 = trunctol(; atol=s * S₀[r + 1])
149149
#
150-
# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc=trunctol(; atol=s * S₀[r + 1]))
150+
# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1]))
151151
# @test length(S2.diag) == r
152152
# @test U1 ≈ U2
153153
# @test S1 ≈ S2

test/chainrules.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ for f in
1212
(
1313
:qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null,
1414
:eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals,
15-
:svd_compact, :svd_trunc, :svd_trunc_with_err, :svd_vals,
15+
:svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals,
1616
:left_polar, :right_polar,
1717
)
1818
copy_f = Symbol(:copy_, f)
@@ -430,12 +430,12 @@ end
430430
ΔUtrunc = ΔU[:, ind]
431431
ΔVᴴtrunc = ΔVᴴ[ind, :]
432432
test_rrule(
433-
copy_svd_trunc_with_err, A, truncalg NoTangent();
433+
copy_svd_trunc, A, truncalg NoTangent();
434434
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))),
435435
atol = atol, rtol = rtol
436436
)
437437
test_rrule(
438-
copy_svd_trunc, A, truncalg NoTangent();
438+
copy_svd_trunc_no_error, A, truncalg NoTangent();
439439
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc),
440440
atol = atol, rtol = rtol
441441
)
@@ -452,12 +452,12 @@ end
452452
ΔUtrunc = ΔU[:, ind]
453453
ΔVᴴtrunc = ΔVᴴ[ind, :]
454454
test_rrule(
455-
copy_svd_trunc_with_err, A, truncalg NoTangent();
455+
copy_svd_trunc, A, truncalg NoTangent();
456456
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))),
457457
atol = atol, rtol = rtol
458458
)
459459
test_rrule(
460-
copy_svd_trunc, A, truncalg NoTangent();
460+
copy_svd_trunc_no_error, A, truncalg NoTangent();
461461
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc),
462462
atol = atol, rtol = rtol
463463
)
@@ -485,13 +485,13 @@ end
485485
trunc = truncrank(r)
486486
ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc)
487487
test_rrule(
488-
config, svd_trunc_with_err, A;
488+
config, svd_trunc, A;
489489
fkwargs = (; trunc = trunc),
490490
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))),
491491
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
492492
)
493493
test_rrule(
494-
config, svd_trunc, A;
494+
config, svd_trunc_no_error, A;
495495
fkwargs = (; trunc = trunc),
496496
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]),
497497
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
@@ -500,13 +500,13 @@ end
500500
trunc = trunctol(; atol = S[1, 1] / 2)
501501
ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc)
502502
test_rrule(
503-
config, svd_trunc_with_err, A;
503+
config, svd_trunc, A;
504504
fkwargs = (; trunc = trunc),
505505
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))),
506506
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
507507
)
508508
test_rrule(
509-
config, svd_trunc, A;
509+
config, svd_trunc_no_error, A;
510510
fkwargs = (; trunc = trunc),
511511
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]),
512512
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false

0 commit comments

Comments
 (0)