Skip to content

Commit 48d9803

Browse files
committed
use truncation_error everywhere
1 parent e0198ef commit 48d9803

7 files changed

Lines changed: 24 additions & 22 deletions

File tree

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module MatrixAlgebraKitChainRulesCoreExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview,
5-
TruncatedAlgorithm, findtruncated, findtruncated_svd, compute_truncerr!
5+
TruncatedAlgorithm, findtruncated, findtruncated_svd, truncation_error
66
using ChainRulesCore
77
using LinearAlgebra
88

@@ -113,7 +113,7 @@ for eig in (:eig, :eigh)
113113
Ac = copy_input($eig_f, A)
114114
DV = $(eig_f!)(Ac, DV, alg.alg)
115115
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
116-
ϵ = compute_truncerr!(copy(diagview(DV[1])), ind)
116+
ϵ = truncation_error(diagview(DV[1]), ind)
117117
return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind)
118118
end
119119
function $(_make_eig_t_pb)(A, DV, ind)
@@ -157,7 +157,7 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
157157
Ac = copy_input(svd_compact, A)
158158
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
159159
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
160-
ϵ = compute_truncerr!(copy(diagview(USVᴴ[2])), ind)
160+
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
161161
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
162162
end
163163
function _make_svd_trunc_pullback(A, USVᴴ, ind)

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
5454
eval(
5555
Expr(
5656
:public, :TruncationByOrder, :TruncationByFilter, :TruncationByValue,
57-
:TruncationByError, :TruncationIntersection, :truncate
57+
:TruncationByError, :TruncationIntersection, :truncate, :truncation_error
5858
)
5959
)
6060
eval(

src/implementations/eig.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ end
109109
function eig_trunc!(A, DV, alg::TruncatedAlgorithm)
110110
D, V = eig_full!(A, DV, alg.alg)
111111
DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc)
112-
return DVtrunc..., compute_truncerr!(diagview(D), ind)
112+
ϵ = truncation_error(diagview(D), ind)
113+
return DVtrunc..., ϵ
113114
end
114115

115116
# Diagonal logic

src/implementations/eigh.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ end
112112
function eigh_trunc!(A, DV, alg::TruncatedAlgorithm)
113113
D, V = eigh_full!(A, DV, alg.alg)
114114
DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc)
115-
return DVtrunc..., compute_truncerr!(diagview(D), ind)
115+
ϵ = truncation_error(diagview(D), ind)
116+
return DVtrunc..., ϵ
116117
end
117118

118119
# Diagonal logic

src/implementations/svd.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ end
239239
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
240240
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
241241
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
242-
return USVᴴtrunc..., compute_truncerr!(diagview(S), ind)
242+
ϵ = truncation_error(diagview(S), ind)
243+
return USVᴴtrunc..., ϵ
243244
end
244245

245246
# Diagonal logic
@@ -385,7 +386,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
385386
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
386387
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
387388
Strunc = diagview(USVᴴtrunc[2])
388-
# normal `compute_truncerr!` does not work here since `S` is not the full singular value spectrum
389+
# normal `truncation_error` does not work here since `S` is not the full singular value spectrum
389390
ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this?
390391
return USVᴴtrunc..., ϵ
391392
end

src/implementations/truncation.jl

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -117,30 +117,26 @@ _ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A
117117
_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B
118118
_ind_intersect(A, B) = intersect(A, B)
119119

120-
# Compute truncation error as 2-norm of discarded values
121-
# by destroying original values
122-
function compute_truncerr!(values::AbstractVector, ind)
123-
values[ind] .= zero(eltype(values))
124-
return norm(values)
125-
end
120+
# Truncation error
121+
# ----------------
122+
# generic fallback: no allocations but inaccurate
123+
truncation_error(values::AbstractVector, ind) = sqrt(norm(values)^2 - norm(view(values, ind))^2)
126124

127-
function compute_truncerr(values::AbstractVector{<:Number}, ind::AbstractUnitRange)
125+
function truncation_error(values::AbstractVector{<:Number}, ind::AbstractUnitRange)
128126
init = abs2(zero(eltype(values)))
129127
return sqrt(
130128
sum(abs2, view(values, firstindex(values):(first(ind) - 1)); init) +
131129
sum(abs2, view(values, (last(ind) + 1):lastindex(values)); init)
132130
)
133131
end
134-
135-
function compute_truncerr(values::AbstractVector{<:Number}, ind::AbstractVector{Bool})
132+
function truncation_error(values::AbstractVector{<:Number}, ind::AbstractVector{Bool})
136133
init = abs2(zero(eltype(values)))
137134
@inbounds for i in eachindex(values, ind)
138135
init += abs2(values[i] * ~(ind[i]))
139136
end
140137
return sqrt(init)
141138
end
142-
143-
function compute_truncerr(values::AbstractVector{<:Number}, ind::AbstractVector{<:Integer})
139+
function truncation_error(values::AbstractVector{<:Number}, ind::AbstractVector{<:Integer})
144140
sort!(ind)
145141
allind = eachindex(IndexLinear(), values)
146142
next_i, next_j = iterate(allind), iterate(ind)
@@ -161,6 +157,3 @@ function compute_truncerr(values::AbstractVector{<:Number}, ind::AbstractVector{
161157

162158
return sqrt(init)
163159
end
164-
165-
# generic fallback: no allocations but inaccurate
166-
compute_truncerr(values::AbstractVector, ind) = sqrt(norm(values)^2 - norm(view(values, ind))^2)

src/interface/truncation.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,9 @@ Base.:&(::NoTruncation, ::NoTruncation) = notrunc()
181181
# disambiguate
182182
Base.:&(::NoTruncation, trunc::TruncationIntersection) = trunc
183183
Base.:&(trunc::TruncationIntersection, ::NoTruncation) = trunc
184+
185+
@doc """
186+
truncation_error(values, ind)
187+
188+
Compute the truncation error as the 2-norm of the values that are not kept by `ind`.
189+
""" truncation_error

0 commit comments

Comments
 (0)