Skip to content

Commit e760bac

Browse files
committed
fixes and keep_below
1 parent 6f994a4 commit e760bac

6 files changed

Lines changed: 58 additions & 35 deletions

File tree

src/implementations/orthnull.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,11 @@ end
203203
# --------------------------------
204204
function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothing)
205205
if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol)
206-
return NoTruncation()
206+
return notrunc()
207207
end
208208
atol = @something atol 0
209209
rtol = @something rtol 0
210-
trunc = trunctol(; atol, rtol, rev=true)
210+
trunc = trunctol(; atol, rtol, keep_below=true)
211211
return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc
212212
end
213213

src/implementations/truncation.jl

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,29 +41,37 @@ function findtruncated(values::AbstractVector, strategy::TruncationByOrder)
4141
return partialsortperm(values, 1:howmany; strategy.by, strategy.rev)
4242
end
4343
function findtruncated_sorted(values::AbstractVector, strategy::TruncationByOrder)
44-
howmany = min(strategy.howmany, length(values))
45-
return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values))
44+
@assert strategy.by === abs
45+
if strategy.by === abs
46+
howmany = min(strategy.howmany, length(values))
47+
return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values))
48+
else
49+
return findtruncated(values, strategy)
50+
end
4651
end
4752

4853
function findtruncated(values::AbstractVector, strategy::TruncationByFilter)
49-
ind = findall(strategy.filter, values)
50-
return ind
54+
# pre-allocate bitvector to enforce the filter function returns a Bool
55+
mask = similar(BitArray, eachindex(values))
56+
mask .= strategy.filter.(values)
57+
return mask
5158
end
5259

5360
function findtruncated(values::AbstractVector, strategy::TruncationByValue)
5461
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
55-
filter = (strategy.rev ? (atol) : (atol)) strategy.by
56-
return findall(filter, values)
62+
filter = (strategy.keep_below ? (atol) : (atol)) strategy.by
63+
return findtruncated(values, truncfilter(filter))
5764
end
58-
function findtruncated_sorted(values::AbstractVector, strategy::TruncationByValue)
65+
function findtruncated_svd(values::AbstractVector, strategy::TruncationByValue)
66+
strategy.by === abs || return findtruncated(values, strategy)
67+
5968
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
60-
@assert strategy.by === abs || strategy.by === real "sorting strategy incompatible with implementation"
61-
if strategy.rev
62-
i = searchsortedfirst(values, atol; by=strategy.by, rev=true)
63-
return i:length(values)
64-
else
65-
i = searchsortedlast(values, atol; by=strategy.by, rev=true)
69+
if strategy.keep_above
70+
i = searchsortedlast(values, atol; by=abs, rev=true)
6671
return 1:i
72+
else
73+
i = searchsortedfirst(values, atol; by=abs, rev=true)
74+
return i:length(values)
6775
end
6876
end
6977

@@ -97,10 +105,20 @@ function _truncerr_impl(values::AbstractVector, I; atol::Real=0, rtol::Real=0, p
97105
end
98106

99107
function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
100-
inds = map(Base.Fix1(findtruncated, values), strategy.components)
101-
return intersect(inds...)
108+
return mapreduce(Base.Fix1(findtruncated, values), _ind_intersect, strategy.components;
109+
init=trues(length(values)))
102110
end
103111
function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection)
104-
inds = map(Base.Fix1(findtruncated_sorted, values), strategy.components)
105-
return intersect(inds...)
112+
return mapreduce(Base.Fix1(findtruncated_sorted, values), _ind_intersect,
113+
strategy.components; init=trues(length(values)))
106114
end
115+
116+
# when one of the ind selections is a bitvector, have to handle differently
117+
function _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector)
118+
result = falses(length(A))
119+
result[B] .= @view A[B]
120+
return result
121+
end
122+
_ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A)
123+
_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B
124+
_ind_intersect(A, B) = intersect(A, B)

src/interface/truncation.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ Truncation strategy that does nothing, and keeps all the values.
4141
"""
4242
notrunc() = NoTruncation()
4343

44+
# TODO: Base.Ordering
45+
4446
"""
4547
TruncationByOrder(howmany::Int, by::Function, rev::Bool)
4648
@@ -59,7 +61,10 @@ end
5961
6062
Truncation strategy to keep the first `howmany` values when sorted according to `by` or the last `howmany` if `rev` is true.
6163
"""
62-
truncrank(howmany::Integer; by=abs, rev::Bool=true) = TruncationByOrder(howmany, by, rev)
64+
function truncrank(howmany::Integer; by=abs, rev::Bool=true)
65+
order = ...
66+
return TruncationByOrder(howmany, order)
67+
end
6368

6469
"""
6570
TruncationByFilter(filter::Function)
@@ -80,31 +85,31 @@ Truncation strategy to keep the values for which `filter` returns true.
8085
truncfilter(f) = TruncationByFilter(f)
8186

8287
"""
83-
TruncationByValue(atol::Real, rtol::Real, p::Real, by, rev::Bool=false)
88+
TruncationByValue(atol::Real, rtol::Real, p::Real, by, keep_below::Bool=false)
8489
85-
Truncation strategy to keep the values that satisfy `by(val) > max(atol, rtol * norm(values, p)`
86-
if `rev = false`, or discard them when `rev = true`.
90+
Truncation strategy to keep the values that satisfy `by(val) > max(atol, rtol * norm(values, p)`.
91+
If `keep_below = true`, discard these values instead.
8792
See also [`trunctol`](@ref)
8893
"""
8994
struct TruncationByValue{T<:Real,P<:Real,F} <: TruncationStrategy
9095
atol::T
9196
rtol::T
9297
p::P
9398
by::F
94-
rev::Bool
99+
keep_below::Bool
95100
end
96-
function TruncationByValue(atol::Real, rtol::Real, p::Real=2, by=abs, rev::Bool=true)
97-
return TruncationByValue(promote(atol, rtol)..., p, by, rev)
101+
function TruncationByValue(atol::Real, rtol::Real, p::Real=2, by=abs, keep_below::Bool=true)
102+
return TruncationByValue(promote(atol, rtol)..., p, by, keep_below)
98103
end
99104

100105
"""
101-
trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, rev::Bool=false)
106+
trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, keep_below::Bool=false)
102107
103-
Truncation strategy to keep the values that satisfy `by(val) > max(atol, rtol * norm(values, p)`
104-
if `rev = false`, or discard them when `rev = true`.
108+
Truncation strategy to keep the values that satisfy `by(val) > max(atol, rtol * norm(values, p)`.
109+
If `keep_below = true`, discard these values instead.
105110
"""
106-
function trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, rev::Bool=false)
107-
return TruncationByValue(atol, rtol, p, by, rev)
111+
function trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, keep_below::Bool=false)
112+
return TruncationByValue(atol, rtol, p, by, keep_below)
108113
end
109114

110115
"""

test/algorithms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ end
5050
notrunc())
5151
end
5252

53-
alg = TruncatedAlgorithm(LAPACK_Simple(), trunctol(; atol=0.1, rev=true))
53+
alg = TruncatedAlgorithm(LAPACK_Simple(), trunctol(; atol=0.1, keep_below=true))
5454
for f in (eig_trunc!, eigh_trunc!, svd_trunc!)
5555
@test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg
5656
@test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc=(; maxrank=2))

test/orthnull.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ end
125125

126126
rtol = eps(real(T))
127127
for (trunc_orth, trunc_null) in (((; rtol=rtol), (; rtol=rtol)),
128-
(trunctol(; rtol), trunctol(; rtol, rev=true)))
128+
(trunctol(; rtol), trunctol(; rtol, keep_below=true)))
129129
V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth)
130130
N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=trunc_null)
131131
@test V2 !== V

test/truncate.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder,
5353
@test @constinferred(findtruncated(values, strategy)) == [2, 3, 4, 5]
5454

5555
for strategy in
56-
(trunctol(; atol=0.4, rev=true), trunctol(; atol=0.2, by=identity, rev=true))
56+
(trunctol(; atol=0.4, keep_below=true), trunctol(; atol=0.2, by=identity, keep_below=true))
5757
@test @constinferred(findtruncated(values, strategy)) == [1, 4]
5858
end
59-
strategy = trunctol(; atol=0.2, rev=true)
59+
strategy = trunctol(; atol=0.2, keep_below=true)
6060
@test @constinferred(findtruncated(values, strategy)) == [1]
6161

6262
strategy = truncfilter(x -> 0.1 < x < 1)

0 commit comments

Comments
 (0)