-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtruncation.jl
More file actions
116 lines (105 loc) · 4.56 KB
/
truncation.jl
File metadata and controls
116 lines (105 loc) · 4.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# truncate!
# ---------
# Generic implementation: `findtruncated` followed by indexing
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
ind = findtruncated_svd(diagview(S), strategy)
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
end
function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
ind = findtruncated(diagview(D), strategy)
return Diagonal(diagview(D)[ind]), V[:, ind]
end
function truncate!(::typeof(eigh_trunc!), (D, V), strategy::TruncationStrategy)
ind = findtruncated(diagview(D), strategy)
return Diagonal(diagview(D)[ind]), V[:, ind]
end
function truncate!(::typeof(left_null!), (U, S), strategy::TruncationStrategy)
# TODO: avoid allocation?
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
ind = findtruncated(extended_S, strategy)
return U[:, ind]
end
function truncate!(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy)
# TODO: avoid allocation?
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1))))
ind = findtruncated(extended_S, strategy)
return Vᴴ[ind, :]
end
# findtruncated
# -------------
# Generic fallback
function findtruncated_svd(values::AbstractVector, strategy::TruncationStrategy)
return findtruncated(values, strategy)
end
# specific implementations for finding truncated values
findtruncated(values::AbstractVector, ::NoTruncation) = Colon()
function findtruncated(values::AbstractVector, strategy::TruncationByOrder)
howmany = min(strategy.howmany, length(values))
return partialsortperm(values, 1:howmany; strategy.by, strategy.rev)
end
function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder)
strategy.by === abs || return findtruncated(values, strategy)
howmany = min(strategy.howmany, length(values))
return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values))
end
function findtruncated(values::AbstractVector, strategy::TruncationByFilter)
return strategy.filter.(values)::AbstractVector{Bool}
end
function findtruncated(values::AbstractVector, strategy::TruncationByValue)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
filter = (strategy.keep_below ? ≤(atol) : ≥(atol)) ∘ strategy.by
return findtruncated(values, truncfilter(filter))
end
function findtruncated_svd(values::AbstractVector, strategy::TruncationByValue)
strategy.by === abs || return findtruncated(values, strategy)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
if strategy.keep_below
i = searchsortedfirst(values, atol; by=abs, rev=true)
return i:length(values)
else
i = searchsortedlast(values, atol; by=abs, rev=true)
return 1:i
end
end
function findtruncated(values::AbstractVector, strategy::TruncationByError)
I = sortperm(values; by=abs, rev=true)
I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p)
return I[I′]
end
function findtruncated_svd(values::AbstractVector, strategy::TruncationByError)
I = eachindex(values)
I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p)
return I[I′]
end
function _truncerr_impl(values::AbstractVector, I; atol::Real=0, rtol::Real=0, p::Real=2)
by = Base.Fix2(^, p) ∘ abs
Nᵖ = sum(by, values)
ϵᵖ = max(atol^p, rtol^p * Nᵖ)
# fast path to avoid checking all values
ϵᵖ ≥ Nᵖ && return Base.OneTo(0)
truncerrᵖ = zero(real(eltype(values)))
rank = length(values)
for i in reverse(I)
truncerrᵖ += by(values[i])
truncerrᵖ ≥ ϵᵖ && break
rank -= 1
end
return Base.OneTo(rank)
end
function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
return mapreduce(Base.Fix1(findtruncated, values), _ind_intersect, strategy.components;
init=trues(length(values)))
end
function findtruncated_svd(values::AbstractVector, strategy::TruncationIntersection)
return mapreduce(Base.Fix1(findtruncated_svd, values), _ind_intersect,
strategy.components; init=trues(length(values)))
end
# when one of the ind selections is a bitvector, have to handle differently
function _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector)
result = falses(length(A))
result[B] .= @view A[B]
return result
end
_ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A)
_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B
_ind_intersect(A, B) = intersect(A, B)