Skip to content

Commit aec6c91

Browse files
committed
add truncerr
1 parent d082c7d commit aec6c91

17 files changed

Lines changed: 122 additions & 75 deletions

File tree

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module MatrixAlgebraKitChainRulesCoreExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview,
5-
TruncatedAlgorithm, findtruncated, findtruncated_svd
5+
TruncatedAlgorithm, findtruncated, findtruncated_svd, compute_truncerr!
66
using ChainRulesCore
77
using LinearAlgebra
88

@@ -113,15 +113,20 @@ for eig in (:eig, :eigh)
113113
Ac = copy_input($eig_f, A)
114114
DV = $(eig_f!)(Ac, DV, alg.alg)
115115
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
116-
return DV′, $(_make_eig_t_pb)(A, DV, ind)
116+
ϵ = compute_truncerr!(diagview(copy(DV[1])), ind)
117+
return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind)
117118
end
118119
function $(_make_eig_t_pb)(A, DV, ind)
119-
function $eig_t_pb(ΔDV)
120+
function $eig_t_pb(ΔDVϵ)
120121
ΔA = zero(A)
121-
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.(ΔDV), ind)
122+
ΔD, ΔV, Δϵ = ΔDVϵ
123+
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
124+
throw(ArgumentError("Pullback for eig_trunc! does not yet support non-zero tangent for the truncation error"))
125+
end
126+
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
122127
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
123128
end
124-
function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
129+
function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
125130
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
126131
end
127132
return $eig_t_pb
@@ -152,15 +157,20 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
152157
Ac = copy_input(svd_compact, A)
153158
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
154159
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
155-
return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind)
160+
ϵ = compute_truncerr!(diagview(copy(USVᴴ[2])), ind)
161+
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
156162
end
157163
function _make_svd_trunc_pullback(A, USVᴴ, ind)
158-
function svd_trunc_pullback(ΔUSVᴴ)
164+
function svd_trunc_pullback(ΔUSVᴴϵ)
159165
ΔA = zero(A)
160-
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ), ind)
166+
ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
167+
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
168+
throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"))
169+
end
170+
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
161171
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
162172
end
163-
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
173+
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
164174
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
165175
end
166176
return svd_trunc_pullback

src/implementations/eig.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ end
108108

109109
function eig_trunc!(A, DV, alg::TruncatedAlgorithm)
110110
D, V = eig_full!(A, DV, alg.alg)
111-
return first(truncate(eig_trunc!, (D, V), alg.trunc))
111+
DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc)
112+
return DVtrunc..., compute_truncerr!(diagview(D), ind)
112113
end
113114

114115
# Diagonal logic

src/implementations/eigh.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ end
111111

112112
function eigh_trunc!(A, DV, alg::TruncatedAlgorithm)
113113
D, V = eigh_full!(A, DV, alg.alg)
114-
return first(truncate(eigh_trunc!, (D, V), alg.trunc))
114+
DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc)
115+
return DVtrunc..., compute_truncerr!(diagview(D), ind)
115116
end
116117

117118
# Diagonal logic

src/implementations/svd.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,9 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
237237
end
238238

239239
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
240-
USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg)
241-
return first(truncate(svd_trunc!, USVᴴ′, alg.trunc))
240+
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
241+
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
242+
return USVᴴtrunc..., compute_truncerr!(diagview(S), ind)
242243
end
243244

244245
# Diagonal logic

src/implementations/truncation.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,10 @@ end
116116
_ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A)
117117
_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B
118118
_ind_intersect(A, B) = intersect(A, B)
119+
120+
# Compute truncation error as 2-norm of discarded values
121+
# by destroying original values
122+
function compute_truncerr!(values::AbstractVector, ind)
123+
values[ind] .= zero(eltype(values))
124+
return norm(values)
125+
end

src/interface/eig.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,19 @@ See also [`eig_vals(!)`](@ref eig_vals) and [`eig_trunc(!)`](@ref eig_trunc).
3131
@functiondef eig_full
3232

3333
"""
34-
eig_trunc(A; kwargs...) -> D, V
35-
eig_trunc(A, alg::AbstractAlgorithm) -> D, V
36-
eig_trunc!(A, [DV]; kwargs...) -> D, V
37-
eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V
34+
eig_trunc(A; kwargs...) -> D, V, ϵ
35+
eig_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ
36+
eig_trunc!(A, [DV]; kwargs...) -> D, V, ϵ
37+
eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ
3838
3939
Compute a partial or truncated eigenvalue decomposition of the matrix `A`,
4040
such that `A * V ≈ V * D`, where the (possibly rectangular) matrix `V` contains
4141
a subset of eigenvectors and the diagonal matrix `D` contains the associated eigenvalues,
4242
selected according to a truncation strategy.
4343
44+
The function also returns `ϵ`, the truncation error defined as the 2-norm of the
45+
discarded eigenvalues.
46+
4447
!!! note
4548
The bang method `eig_trunc!` optionally accepts the output structure and
4649
possibly destroys the input matrix `A`. Always use the return value of the function

src/interface/eigh.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,18 @@ For generic eigenvalue decompositions, see [`eig_full`](@ref).
99
"""
1010

1111
"""
12-
eigh_full(A; kwargs...) -> D, V
13-
eigh_full(A, alg::AbstractAlgorithm) -> D, V
14-
eigh_full!(A, [DV]; kwargs...) -> D, V
15-
eigh_full!(A, [DV], alg::AbstractAlgorithm) -> D, V
12+
eigh_full(A; kwargs...) -> D, V, ϵ
13+
eigh_full(A, alg::AbstractAlgorithm) -> D, V, ϵ
14+
eigh_full!(A, [DV]; kwargs...) -> D, V, ϵ
15+
eigh_full!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ
1616
1717
Compute the full eigenvalue decomposition of the symmetric or hermitian matrix `A`,
1818
such that `A * V = V * D`, where the unitary matrix `V` contains the orthogonal eigenvectors
1919
and the real diagonal matrix `D` contains the associated eigenvalues.
2020
21+
The function also returns `ϵ`, the truncation error defined as the 2-norm of the
22+
discarded eigenvalues.
23+
2124
!!! note
2225
The bang method `eigh_full!` optionally accepts the output structure and
2326
possibly destroys the input matrix `A`. Always use the return value of the function

src/interface/svd.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,19 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and
4343

4444
# TODO: decide if we should have `svd_trunc!!` instead
4545
"""
46-
svd_trunc(A; kwargs...) -> U, S, Vᴴ
47-
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ
48-
svd_trunc!(A, [USVᴴ]; kwargs...) -> U, S, Vᴴ
49-
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ
46+
svd_trunc(A; kwargs...) -> U, S, Vᴴ, ϵ
47+
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
48+
svd_trunc!(A, [USVᴴ]; kwargs...) -> U, S, Vᴴ, ϵ
49+
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
5050
5151
Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
52-
`A * (Vᴴ)' = U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
52+
`A * (Vᴴ)' U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
5353
`(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a
5454
square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy.
5555
56+
The function also returns `ϵ`, the truncation error defined as the 2-norm of the
57+
discarded singular values.
58+
5659
!!! note
5760
The bang method `svd_trunc!` optionally accepts the output structure and
5861
possibly destroys the input matrix `A`. Always use the return value of the function

test/amd/eigh.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ end
4646
r = m - 2
4747
s = 1 + sqrt(eps(real(T)))
4848
49-
D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r))
49+
D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r))
5050
@test length(diagview(D1)) == r
5151
@test isisometry(V1)
5252
@test A * V1 ≈ V1 * D1
5353
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
5454
5555
trunc = trunctol(; atol=s * D₀[r + 1])
56-
D2, V2 = @constinferred eigh_trunc(A; alg, trunc)
56+
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc)
5757
@test length(diagview(D2)) == r
5858
@test isisometry(V2)
5959
@test A * V2 ≈ V2 * D2
@@ -75,7 +75,7 @@ end
7575
A = V * D * V'
7676
A = (A + A') / 2
7777
alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2))
78-
D2, V2 = @constinferred eigh_trunc(A; alg)
78+
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
7979
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
8080
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
8181
end=#

test/amd/svd.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,23 +94,23 @@ end
9494
# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(),
9595
# LAPACK_Jacobi())
9696
# end
97-
97+
#
9898
# @testset "size ($m, $n)" for n in (37, m, 63)
9999
# @testset "algorithm $alg" for alg in algs
100100
# n > m && alg isa LAPACK_Jacobi && continue # not supported
101101
# A = randn(rng, T, m, n)
102102
# S₀ = svd_vals(A)
103103
# minmn = min(m, n)
104104
# r = minmn - 2
105-
106-
# U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r))
105+
#
106+
# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r))
107107
# @test length(S1.diag) == r
108108
# @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]
109-
109+
#
110110
# s = 1 + sqrt(eps(real(T)))
111111
# trunc2 = trunctol(; atol=s * S₀[r + 1])
112-
113-
# U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1]))
112+
#
113+
# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1]))
114114
# @test length(S2.diag) == r
115115
# @test U1 ≈ U2
116116
# @test S1 ≈ S2

0 commit comments

Comments
 (0)