Skip to content

Commit ba9867b

Browse files
Jutholkdvos
andauthored
Output truncation error for truncated decompositions (#75)
* add truncerr * some final fixes * fix cuda test * make `trunc` kwarg explicit in docstrings * update docs * bump JET compat * rename `truncation_error(!)` * docstring fixes * relax truncation error tolerances * some test streamline --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent b825314 commit ba9867b

20 files changed

Lines changed: 209 additions & 123 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Aqua = "0.6, 0.7, 0.8"
2222
ChainRulesCore = "1"
2323
ChainRulesTestUtils = "1"
2424
CUDA = "5"
25-
JET = "0.9"
25+
JET = "0.9, 0.10"
2626
LinearAlgebra = "1"
2727
SafeTestsets = "0.1"
2828
StableRNGs = "1"

docs/src/user_interface/truncations.md

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,15 @@ Truncation strategies allow you to control which eigenvalues or singular values
1212
Truncation strategies can be used with truncated decomposition functions in two ways, as illustrated below.
1313
For concreteness, we use the following matrix as an example:
1414

15-
```jldoctest truncations
15+
```jldoctest truncations; output=false
1616
using MatrixAlgebraKit
1717
using MatrixAlgebraKit: diagview
1818
1919
A = [2 1 0; 1 3 1; 0 1 4];
2020
D, V = eigh_full(A);
21-
2221
diagview(D) ≈ [3 - √3, 3, 3 + √3]
2322
2423
# output
25-
2624
true
2725
```
2826

@@ -31,38 +29,35 @@ true
3129
The simplest approach is to pass a `NamedTuple` with the truncation parameters.
3230
For example, keeping only the largest 2 eigenvalues:
3331

34-
```jldoctest truncations
35-
Dtrunc, Vtrunc = eigh_trunc(A; trunc = (maxrank = 2,));
32+
```jldoctest truncations; output=false
33+
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2,));
3634
size(Dtrunc, 1) <= 2
3735
3836
# output
39-
4037
true
4138
```
4239

4340
Note however that there are no guarantees on the order of the output values:
4441

45-
```jldoctest truncations
42+
```jldoctest truncations; output=false
4643
diagview(Dtrunc) ≈ diagview(D)[[3, 2]]
4744
4845
# output
49-
5046
true
5147
```
5248

5349
You can also use tolerance-based truncation or combine multiple criteria:
5450

55-
```jldoctest truncations
56-
Dtrunc, Vtrunc = eigh_trunc(A; trunc = (atol = 2.9,));
51+
```jldoctest truncations; output=false
52+
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (atol = 2.9,));
5753
all(>(2.9), diagview(Dtrunc))
5854
5955
# output
60-
6156
true
6257
```
6358

64-
```jldoctest truncations
65-
Dtrunc, Vtrunc = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9));
59+
```jldoctest truncations; output=false
60+
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9));
6661
size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))
6762
6863
# output
@@ -72,7 +67,7 @@ true
7267
In general, the keyword arguments that are supported can be found in the `TruncationStrategy` docstring:
7368

7469
```@docs; canonical = false
75-
TruncationStrategy
70+
TruncationStrategy()
7671
```
7772

7873

@@ -81,33 +76,22 @@ TruncationStrategy
8176
For more control, you can construct [`TruncationStrategy`](@ref) objects directly.
8277
This is also what the previous syntax will end up calling.
8378

84-
```jldoctest truncations
79+
```jldoctest truncations; output=false
8580
Dtrunc, Vtrunc = eigh_trunc(A; trunc = truncrank(2))
8681
size(Dtrunc, 1) <= 2
8782
8883
# output
89-
9084
true
9185
```
9286

93-
```jldoctest truncations
94-
Dtrunc, Vtrunc = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9))
87+
```jldoctest truncations; output=false
88+
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9))
9589
size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))
9690
9791
# output
9892
true
9993
```
10094

101-
## Truncation with SVD vs Eigenvalue Decompositions
102-
103-
When using truncations with different decomposition types, keep in mind:
104-
105-
- **`svd_trunc`**: Singular values are always real and non-negative, sorted in descending order. Truncation by value typically keeps the largest singular values.
106-
107-
- **`eigh_trunc`**: Eigenvalues are real but can be negative for symmetric matrices. By default, `truncrank` sorts by absolute value, so `truncrank(k)` keeps the `k` eigenvalues with largest magnitude (positive or negative).
108-
109-
- **`eig_trunc`**: For general (non-symmetric) matrices, eigenvalues can be complex. Truncation by absolute value considers the complex magnitude.
110-
11195
## Truncation Strategies
11296

11397
MatrixAlgebraKit provides several built-in truncation strategies:
@@ -127,3 +111,31 @@ When strategies are combined, only the values that satisfy all conditions are ke
127111
combined_trunc = truncrank(10) & trunctol(; atol = 1e-6);
128112
```
129113

114+
## Truncation Error
115+
116+
When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned.
117+
This error is defined as the 2-norm of the discarded singular values or eigenvalues, providing a measure of the approximation quality.
118+
For `svd_trunc` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix.
119+
For the case of `eig_trunc`, this interpretation does not hold because the norm of the non-unitary matrix of eigenvectors and its inverse also influence the approximation quality.
120+
121+
122+
For example:
123+
```jldoctest truncations; output=false
124+
using LinearAlgebra: norm
125+
U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2))
126+
norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values
127+
128+
# output
129+
true
130+
```
131+
132+
### Truncation with SVD vs Eigenvalue Decompositions
133+
134+
When using truncations with different decomposition types, keep in mind:
135+
136+
- **[`svd_trunc`](@ref)**: Singular values are always real and non-negative, sorted in descending order. Truncation by value typically keeps the largest singular values. The truncation error gives the 2-norm difference between the original and the truncated matrix.
137+
138+
- **[`eigh_trunc`](@ref)**: Eigenvalues are real but can be negative for symmetric matrices. By default, eigenvalues are treated by absolute value, e.g. `truncrank(k)` keeps the `k` eigenvalues with largest magnitude (positive or negative). The truncation error gives the 2-norm difference between the original and the truncated matrix.
139+
140+
- **[`eig_trunc`](@ref)**: For general (non-symmetric) matrices, eigenvalues can be complex. By default, eigenvalues are treated by absolute value. The truncation error gives an indication of the magnitude of discarded values, but is not directly related to the 2-norm difference between the original and the truncated matrix.
141+

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module MatrixAlgebraKitChainRulesCoreExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview,
5-
TruncatedAlgorithm, findtruncated, findtruncated_svd
5+
TruncatedAlgorithm, findtruncated, findtruncated_svd, truncation_error
66
using ChainRulesCore
77
using LinearAlgebra
88

@@ -113,15 +113,20 @@ for eig in (:eig, :eigh)
113113
Ac = copy_input($eig_f, A)
114114
DV = $(eig_f!)(Ac, DV, alg.alg)
115115
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
116-
return DV′, $(_make_eig_t_pb)(A, DV, ind)
116+
ϵ = truncation_error(diagview(DV[1]), ind)
117+
return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind)
117118
end
118119
function $(_make_eig_t_pb)(A, DV, ind)
119-
function $eig_t_pb(ΔDV)
120+
function $eig_t_pb(ΔDVϵ)
120121
ΔA = zero(A)
121-
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.(ΔDV), ind)
122+
ΔD, ΔV, Δϵ = ΔDVϵ
123+
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
124+
throw(ArgumentError("Pullback for eig_trunc! does not yet support non-zero tangent for the truncation error"))
125+
end
126+
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
122127
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
123128
end
124-
function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
129+
function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
125130
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
126131
end
127132
return $eig_t_pb
@@ -152,15 +157,20 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
152157
Ac = copy_input(svd_compact, A)
153158
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
154159
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
155-
return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind)
160+
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
161+
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
156162
end
157163
function _make_svd_trunc_pullback(A, USVᴴ, ind)
158-
function svd_trunc_pullback(ΔUSVᴴ)
164+
function svd_trunc_pullback(ΔUSVᴴϵ)
159165
ΔA = zero(A)
160-
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ), ind)
166+
ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
167+
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
168+
throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"))
169+
end
170+
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
161171
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
162172
end
163-
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
173+
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
164174
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
165175
end
166176
return svd_trunc_pullback

src/implementations/eig.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ end
108108

109109
function eig_trunc!(A, DV, alg::TruncatedAlgorithm)
110110
D, V = eig_full!(A, DV, alg.alg)
111-
return first(truncate(eig_trunc!, (D, V), alg.trunc))
111+
DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc)
112+
return DVtrunc..., truncation_error!(diagview(D), ind)
112113
end
113114

114115
# Diagonal logic

src/implementations/eigh.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ end
111111

112112
function eigh_trunc!(A, DV, alg::TruncatedAlgorithm)
113113
D, V = eigh_full!(A, DV, alg.alg)
114-
return first(truncate(eigh_trunc!, (D, V), alg.trunc))
114+
DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc)
115+
return DVtrunc..., truncation_error!(diagview(D), ind)
115116
end
116117

117118
# Diagonal logic

src/implementations/svd.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,9 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
237237
end
238238

239239
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
240-
USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg)
241-
return first(truncate(svd_trunc!, USVᴴ′, alg.trunc))
240+
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
241+
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
242+
return USVᴴtrunc..., truncation_error!(diagview(S), ind)
242243
end
243244

244245
# Diagonal logic
@@ -381,7 +382,12 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
381382
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
382383
# TODO: make this controllable using a `gaugefix` keyword argument
383384
gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...)
384-
return first(truncate(svd_trunc!, USVᴴ, alg.trunc))
385+
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
386+
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
387+
Strunc = diagview(USVᴴtrunc[2])
388+
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
389+
ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this?
390+
return USVᴴtrunc..., ϵ
385391
end
386392

387393
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)

src/implementations/truncation.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,13 @@ end
116116
_ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A)
117117
_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B
118118
_ind_intersect(A, B) = intersect(A, B)
119+
120+
# Truncation error
121+
# ----------------
122+
truncation_error(values::AbstractVector, ind) = truncation_error!(copy(values), ind)
123+
# destroys input in order to maximize accuracy:
124+
# sqrt(norm(values)^2 - norm(values[ind])^2) might suffer from floating point error
125+
function truncation_error!(values::AbstractVector, ind)
126+
values[ind] .= zero(eltype(values))
127+
return norm(values)
128+
end

src/interface/eig.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,19 @@ See also [`eig_vals(!)`](@ref eig_vals) and [`eig_trunc(!)`](@ref eig_trunc).
3232
@functiondef eig_full
3333

3434
"""
35-
eig_trunc(A; [trunc], kwargs...) -> D, V
36-
eig_trunc(A, alg::AbstractAlgorithm) -> D, V
37-
eig_trunc!(A, [DV]; [trunc], kwargs...) -> D, V
38-
eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V
35+
eig_trunc(A; [trunc], kwargs...) -> D, V, ϵ
36+
eig_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ
37+
eig_trunc!(A, [DV]; [trunc], kwargs...) -> D, V, ϵ
38+
eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ
3939
4040
Compute a partial or truncated eigenvalue decomposition of the matrix `A`,
4141
such that `A * V ≈ V * D`, where the (possibly rectangular) matrix `V` contains
4242
a subset of eigenvectors and the diagonal matrix `D` contains the associated eigenvalues,
4343
selected according to a truncation strategy.
4444
45+
The function also returns `ϵ`, the truncation error defined as the 2-norm of the
46+
discarded eigenvalues.
47+
4548
## Keyword arguments
4649
The behavior of this function is controlled by the following keyword arguments:
4750

src/interface/eigh.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,18 @@ docs_eigh_note = """
1212
"""
1313

1414
"""
15-
eigh_full(A; kwargs...) -> D, V
16-
eigh_full(A, alg::AbstractAlgorithm) -> D, V
17-
eigh_full!(A, [DV]; kwargs...) -> D, V
18-
eigh_full!(A, [DV], alg::AbstractAlgorithm) -> D, V
15+
eigh_full(A; kwargs...) -> D, V, ϵ
16+
eigh_full(A, alg::AbstractAlgorithm) -> D, V, ϵ
17+
eigh_full!(A, [DV]; kwargs...) -> D, V, ϵ
18+
eigh_full!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ
1919
2020
Compute the full eigenvalue decomposition of the symmetric or hermitian matrix `A`,
2121
such that `A * V = V * D`, where the unitary matrix `V` contains the orthogonal eigenvectors
2222
and the real diagonal matrix `D` contains the associated eigenvalues.
2323
24+
The function also returns `ϵ`, the truncation error defined as the 2-norm of the
25+
discarded eigenvalues.
26+
2427
!!! note
2528
The bang method `eigh_full!` optionally accepts the output structure and
2629
possibly destroys the input matrix `A`. Always use the return value of the function
@@ -34,16 +37,19 @@ See also [`eigh_vals(!)`](@ref eigh_vals) and [`eigh_trunc(!)`](@ref eigh_trunc)
3437
@functiondef eigh_full
3538

3639
"""
37-
eigh_trunc(A; [trunc], kwargs...) -> D, V
38-
eigh_trunc(A, alg::AbstractAlgorithm) -> D, V
39-
eigh_trunc!(A, [DV]; [trunc], kwargs...) -> D, V
40-
eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V
40+
eigh_trunc(A; [trunc], kwargs...) -> D, V, ϵ
41+
eigh_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ
42+
eigh_trunc!(A, [DV]; [trunc], kwargs...) -> D, V, ϵ
43+
eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ
4144
4245
Compute a partial or truncated eigenvalue decomposition of the symmetric or hermitian matrix
4346
`A`, such that `A * V ≈ V * D`, where the isometric matrix `V` contains a subset of the
4447
orthogonal eigenvectors and the real diagonal matrix `D` contains the associated eigenvalues,
4548
selected according to a truncation strategy.
4649
50+
The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded
51+
eigenvalues.
52+
4753
## Keyword arguments
4854
The behavior of this function is controlled by the following keyword arguments:
4955

src/interface/svd.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,19 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and
4242
@functiondef svd_compact
4343

4444
"""
45-
svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ
46-
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ
47-
svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ
48-
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ
45+
svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
46+
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
47+
svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
48+
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
4949
5050
Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
51-
`A * (Vᴴ)' = U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
51+
`A * (Vᴴ)' U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
5252
`(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a
5353
square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy.
5454
55+
The function also returns `ϵ`, the truncation error defined as the 2-norm of the
56+
discarded singular values.
57+
5558
## Keyword arguments
5659
The behavior of this function is controlled by the following keyword arguments:
5760

0 commit comments

Comments
 (0)