@@ -11,10 +11,17 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing)
1111 if isnothing (maxrank) && isnothing (atol) && isnothing (rtol)
1212 return NoTruncation ()
1313 elseif isnothing (maxrank)
14- @assert isnothing (rtol) " TODO: rtol"
15- return trunctol (atol)
14+ atol = @something atol 0
15+ rtol = @something rtol 0
16+ return TruncationKeepAbove (atol, rtol)
1617 else
17- return truncrank (maxrank)
18+ if isnothing (atol) && isnothing (rtol)
19+ return truncrank (maxrank)
20+ else
21+ atol = @something atol 0
22+ rtol = @something rtol 0
23+ return truncrank (maxrank) & TruncationKeepAbove (atol, rtol)
24+ end
1825 end
1926end
2027
@@ -82,6 +89,27 @@ Truncation strategy to discard the values that are larger than `atol` in absolut
8289"""
8390truncabove (atol) = TruncationKeepFiltered (≤ (atol) ∘ abs)
8491
92+ """
93+ TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
94+ Compose two truncation strategies, keeping values common between the two strategies.
95+ """
96+ struct TruncationComposition{T<: Tuple{Vararg{TruncationStrategy}} } < :
97+ TruncationStrategy
98+ components:: T
99+ end
100+ function Base.:& (trunc1:: TruncationStrategy , trunc2:: TruncationStrategy )
101+ return TruncationComposition ((trunc1, trunc2))
102+ end
103+ function Base.:& (trunc1:: TruncationComposition , trunc2:: TruncationComposition )
104+ return TruncationComposition ((trunc1. components... , trunc2. components... ))
105+ end
106+ function Base.:& (trunc1:: TruncationComposition , trunc2:: TruncationStrategy )
107+ return TruncationComposition ((trunc1. components... , trunc2))
108+ end
109+ function Base.:& (trunc1:: TruncationStrategy , trunc2:: TruncationComposition )
110+ return TruncationComposition ((trunc1, trunc2. components... ))
111+ end
112+
85113# truncate!
86114# ---------
87115# Generic implementation: `findtruncated` followed by indexing
@@ -147,6 +175,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
147175 return 1 : i
148176end
149177
178+ function findtruncated (values:: AbstractVector , strategy:: TruncationComposition )
179+ inds = map (Base. Fix1 (findtruncated, values), strategy. components)
180+ return intersect (inds... )
181+ end
182+
150183"""
151184 TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
152185
0 commit comments