Skip to content

Commit 77cc6c3

Browse files
committed
Flatten composition of compositions
1 parent 9707ab1 commit 77cc6c3

1 file changed

Lines changed: 14 additions & 7 deletions

File tree

src/implementations/truncation.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,21 @@ truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs)
9494
9595
Compose two truncation strategies, keeping values common between the two strategies.
9696
"""
97-
struct TruncationComposition{T1<:TruncationStrategy,T2<:TruncationStrategy} <:
97+
struct TruncationComposition{T<:Tuple{Vararg{TruncationStrategy}}} <:
9898
TruncationStrategy
99-
trunc1::T1
100-
trunc2::T2
99+
components::T
101100
end
102101
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
103-
return TruncationComposition(trunc1, trunc2)
102+
return TruncationComposition((trunc1, trunc2))
103+
end
104+
function Base.:&(trunc1::TruncationComposition, trunc2::TruncationComposition)
105+
return TruncationComposition((trunc1.components..., trunc2.components...))
106+
end
107+
function Base.:&(trunc1::TruncationComposition, trunc2::TruncationStrategy)
108+
return TruncationComposition((trunc1.components..., trunc2))
109+
end
110+
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationComposition)
111+
return TruncationComposition((trunc1, trunc2.components...))
104112
end
105113

106114
# truncate!
@@ -169,9 +177,8 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
169177
end
170178

171179
function findtruncated(values::AbstractVector, strategy::TruncationComposition)
172-
ind1 = findtruncated(values, strategy.trunc1)
173-
ind2 = findtruncated(values, strategy.trunc2)
174-
return ind1 ind2
180+
inds = map(Base.Fix1(findtruncated, values), strategy.components)
181+
return intersect(inds...)
175182
end
176183

177184
"""

0 commit comments

Comments
 (0)