Skip to content

Commit f1147f5

Browse files
committed
rename truncation_error(!)
1 parent 3a0c544 commit f1147f5

5 files changed

Lines changed: 21 additions & 10 deletions

File tree

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module MatrixAlgebraKitChainRulesCoreExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview,
5-
TruncatedAlgorithm, findtruncated, findtruncated_svd, compute_truncerr!
5+
TruncatedAlgorithm, findtruncated, findtruncated_svd, truncation_error
66
using ChainRulesCore
77
using LinearAlgebra
88

@@ -113,7 +113,7 @@ for eig in (:eig, :eigh)
113113
Ac = copy_input($eig_f, A)
114114
DV = $(eig_f!)(Ac, DV, alg.alg)
115115
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
116-
ϵ = compute_truncerr!(copy(diagview(DV[1])), ind)
116+
ϵ = truncation_error(diagview(DV[1]), ind)
117117
return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind)
118118
end
119119
function $(_make_eig_t_pb)(A, DV, ind)
@@ -157,7 +157,7 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
157157
Ac = copy_input(svd_compact, A)
158158
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
159159
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
160-
ϵ = compute_truncerr!(copy(diagview(USVᴴ[2])), ind)
160+
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
161161
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
162162
end
163163
function _make_svd_trunc_pullback(A, USVᴴ, ind)

src/implementations/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ end
109109
function eig_trunc!(A, DV, alg::TruncatedAlgorithm)
110110
D, V = eig_full!(A, DV, alg.alg)
111111
DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc)
112-
return DVtrunc..., compute_truncerr!(diagview(D), ind)
112+
return DVtrunc..., truncation_error!(diagview(D), ind)
113113
end
114114

115115
# Diagonal logic

src/implementations/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ end
112112
function eigh_trunc!(A, DV, alg::TruncatedAlgorithm)
113113
D, V = eigh_full!(A, DV, alg.alg)
114114
DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc)
115-
return DVtrunc..., compute_truncerr!(diagview(D), ind)
115+
return DVtrunc..., truncation_error!(diagview(D), ind)
116116
end
117117

118118
# Diagonal logic

src/implementations/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ end
239239
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
240240
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
241241
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
242-
return USVᴴtrunc..., compute_truncerr!(diagview(S), ind)
242+
return USVᴴtrunc..., truncation_error!(diagview(S), ind)
243243
end
244244

245245
# Diagonal logic
@@ -385,7 +385,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
385385
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
386386
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
387387
Strunc = diagview(USVᴴtrunc[2])
388-
# normal `compute_truncerr!` does not work here since `S` is not the full singular value spectrum
388+
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
389389
ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this?
390390
return USVᴴtrunc..., ϵ
391391
end

src/implementations/truncation.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,20 @@ _ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A
117117
_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B
118118
_ind_intersect(A, B) = intersect(A, B)
119119

120-
# Compute truncation error as 2-norm of discarded values
121-
# by destroying original values
122-
function compute_truncerr!(values::AbstractVector, ind)
120+
# Truncation error
121+
# ----------------
122+
@doc """
123+
truncation_error(values, ind)
124+
truncation_error!(values, ind)
125+
126+
Determine the truncation error of selecting `ind` out of the `values`.
127+
This is defined as the 2-norm of the discarded values.
128+
""" truncation_error, truncation_error!
129+
130+
truncation_error(values::AbstractVector, ind) = truncation_error!(copy(values), ind)
131+
# destroys input in order to maximize accuracy:
132+
# sqrt(norm(values)^2 - norm(values[ind])^2) might suffer from floating point error
133+
function truncation_error!(values::AbstractVector, ind)
123134
values[ind] .= zero(eltype(values))
124135
return norm(values)
125136
end

0 commit comments

Comments
 (0)