-
Notifications
You must be signed in to change notification settings - Fork 6
TruncationStrategy types and constructors: consistency in names and implementations
#56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
6c1255b
07ce4e3
62f34ec
dddb770
f7db753
d72ce5a
24c3759
cb0b0ed
b8e6bb7
6f994a4
852774a
b5d60ed
5e5a532
fdb3657
5d720a3
7199486
7767066
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,7 @@ | |
| # --------- | ||
| # Generic implementation: `findtruncated` followed by indexing | ||
| function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy) | ||
| ind = findtruncated_sorted(diagview(S), strategy) | ||
| ind = findtruncated_svd(diagview(S), strategy) | ||
| return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :] | ||
| end | ||
| function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy) | ||
|
|
@@ -28,82 +28,89 @@ end | |
|
|
||
| # findtruncated | ||
| # ------------- | ||
| # Generic fallback | ||
| function findtruncated_svd(values::AbstractVector, strategy::TruncationStrategy) | ||
| return findtruncated(values, strategy) | ||
| end | ||
|
|
||
| # specific implementations for finding truncated values | ||
| findtruncated(values::AbstractVector, ::NoTruncation) = Colon() | ||
|
|
||
| function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted) | ||
| function findtruncated(values::AbstractVector, strategy::TruncationByOrder) | ||
| howmany = min(strategy.howmany, length(values)) | ||
| return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev) | ||
| return partialsortperm(values, 1:howmany; strategy.by, strategy.rev) | ||
| end | ||
| function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepSorted) | ||
| function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder) | ||
| strategy.by === abs || return findtruncated(values, strategy) | ||
| howmany = min(strategy.howmany, length(values)) | ||
| return 1:howmany | ||
| return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values)) | ||
| end | ||
|
|
||
| # TODO: consider if worth using that values are sorted when filter is `<` or `>`. | ||
| function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered) | ||
| ind = findall(strategy.filter, values) | ||
| return ind | ||
| function findtruncated(values::AbstractVector, strategy::TruncationByFilter) | ||
| return strategy.filter.(values)::AbstractVector{Bool} | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it really worth returning a bool vector here, instead of just indices? Is
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, is the type annotation really useful? Is this a JET thing again?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I really don't have any opinion here. @mtfishman, what do you think? I am fine with either way.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm fine either way, I think the main considerations were:
|
||
| end | ||
|
|
||
| function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow) | ||
| function findtruncated(values::AbstractVector, strategy::TruncationByValue) | ||
| atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) | ||
| return findall(≤(atol) ∘ strategy.by, values) | ||
| filter = (strategy.keep_below ? ≤(atol) : ≥(atol)) ∘ strategy.by | ||
| return findtruncated(values, truncfilter(filter)) | ||
| end | ||
| function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow) | ||
| function findtruncated_svd(values::AbstractVector, strategy::TruncationByValue) | ||
| strategy.by === abs || return findtruncated(values, strategy) | ||
|
Jutho marked this conversation as resolved.
|
||
| atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) | ||
| i = searchsortedfirst(values, atol; by=strategy.by, rev=true) | ||
| return i:length(values) | ||
| end | ||
|
|
||
| function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove) | ||
| atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) | ||
| return findall(≥(atol) ∘ strategy.by, values) | ||
| end | ||
| function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove) | ||
| atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) | ||
| i = searchsortedlast(values, atol; by=strategy.by, rev=true) | ||
| return 1:i | ||
| end | ||
|
|
||
| function findtruncated(values::AbstractVector, strategy::TruncationIntersection) | ||
| inds = map(Base.Fix1(findtruncated, values), strategy.components) | ||
| return intersect(inds...) | ||
| end | ||
| function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection) | ||
| inds = map(Base.Fix1(findtruncated_sorted, values), strategy.components) | ||
| return intersect(inds...) | ||
| if strategy.keep_below | ||
| i = searchsortedfirst(values, atol; by=abs, rev=true) | ||
| return i:length(values) | ||
| else | ||
| i = searchsortedlast(values, atol; by=abs, rev=true) | ||
| return 1:i | ||
| end | ||
| end | ||
|
|
||
| function findtruncated(values::AbstractVector, strategy::TruncationError) | ||
| function findtruncated(values::AbstractVector, strategy::TruncationByError) | ||
| I = sortperm(values; by=abs, rev=true) | ||
| I′ = _truncerr_impl(values, I, strategy) | ||
| I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p) | ||
| return I[I′] | ||
| end | ||
| function findtruncated_sorted(values::AbstractVector, strategy::TruncationError) | ||
| function findtruncated_svd(values::AbstractVector, strategy::TruncationByError) | ||
| I = eachindex(values) | ||
| I′ = _truncerr_impl(values, I, strategy) | ||
| I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p) | ||
| return I[I′] | ||
| end | ||
| function _truncerr_impl(values::AbstractVector, I, strategy::TruncationError) | ||
| Nᵖ = sum(Base.Fix2(^, strategy.p) ∘ abs, values) | ||
| ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ) | ||
| function _truncerr_impl(values::AbstractVector, I; atol::Real=0, rtol::Real=0, p::Real=2) | ||
| by = Base.Fix2(^, p) ∘ abs | ||
| Nᵖ = sum(by, values) | ||
| ϵᵖ = max(atol^p, rtol^p * Nᵖ) | ||
|
|
||
| # fast path to avoid checking all values | ||
| ϵᵖ ≥ Nᵖ && return Base.OneTo(0) | ||
|
|
||
| truncerrᵖ = zero(real(eltype(values))) | ||
| rank = length(values) | ||
| for i in reverse(I) | ||
| truncerrᵖ += abs(values[i])^strategy.p | ||
| if truncerrᵖ ≥ ϵᵖ | ||
| break | ||
| else | ||
| rank -= 1 | ||
| end | ||
| truncerrᵖ += by(values[i]) | ||
| truncerrᵖ ≥ ϵᵖ && break | ||
| rank -= 1 | ||
| end | ||
|
|
||
| return Base.OneTo(rank) | ||
| end | ||
|
|
||
| # Generic fallback | ||
| function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) | ||
| return findtruncated(values, strategy) | ||
| function findtruncated(values::AbstractVector, strategy::TruncationIntersection) | ||
| return mapreduce(Base.Fix1(findtruncated, values), _ind_intersect, strategy.components; | ||
| init=trues(length(values))) | ||
| end | ||
| function findtruncated_svd(values::AbstractVector, strategy::TruncationIntersection) | ||
| return mapreduce(Base.Fix1(findtruncated_svd, values), _ind_intersect, | ||
| strategy.components; init=trues(length(values))) | ||
| end | ||
|
|
||
| # when one of the ind selections is a bitvector, have to handle differently | ||
| function _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector) | ||
| result = falses(length(A)) | ||
| result[B] .= @view A[B] | ||
| return result | ||
| end | ||
| _ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A) | ||
| _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B | ||
| _ind_intersect(A, B) = intersect(A, B) | ||
Uh oh!
There was an error while loading. Please reload this page.