From 6c1255bed63e897eb02adf8a8a4c0c7237b8af26 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 20 Sep 2025 08:13:36 +0200 Subject: [PATCH 01/17] Change `trunctol` to use `TruncationKeepBelow` --- src/interface/truncation.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index 8710570df..6d63ca8f1 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -69,13 +69,6 @@ struct TruncationKeepFiltered{F} <: TruncationStrategy filter::F end -""" - trunctol(val::Real; by=abs) - -Truncation strategy to discard the values that are smaller than `val` according to `by`. -""" -trunctol(val::Real; by=abs) = TruncationKeepFiltered(≥(val) ∘ by) - """ truncabove(val::Real; by=abs) @@ -114,6 +107,15 @@ function TruncationKeepBelow(atol::Real, rtol::Real, p::Real=2, by=abs) return TruncationKeepBelow(promote(atol, rtol)..., p, by) end +""" + trunctol(; atol::Real, rtol::Real, p::Real=2, by=abs) + +Truncation strategy to discard all values that satisfy `by(val) < max(atol, rtol * norm(values))`. +""" +function trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs) + return TruncationKeepBelow(; atol, rtol, p, by) +end + """ TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) From 07ce4e3069f44fb5dfddd4945b764a40a2ba0fb7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 20 Sep 2025 09:03:59 +0200 Subject: [PATCH 02/17] Refactor truncation names --- src/MatrixAlgebraKit.jl | 18 +++-- src/implementations/orthnull.jl | 2 +- src/implementations/truncation.jl | 87 ++++++++++++------------- src/interface/truncation.jl | 105 +++++++++++++----------------- 4 files changed, 97 insertions(+), 115 deletions(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index a2fe9a086..a88c1747e 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -28,16 +28,14 @@ export left_polar!, right_polar! export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! -export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, - LAPACK_Simple, LAPACK_Expert, - LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, - LAPACK_DivideAndConquer, LAPACK_Jacobi, - LQViaTransposedQR, - CUSOLVER_Simple, - CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer, - ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection, - DiagonalAlgorithm -export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered, truncerror +export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi +export LQViaTransposedQR +export DiagonalAlgorithm +export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer +export ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, + ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection + +export notrunc, truncrank, trunctol, truncerror, truncfilter VERSION >= v"1.11.0-DEV.469" && eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_sorted, diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index c6dc6248a..693a3ab9c 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -207,7 +207,7 @@ function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothi end atol = @something atol 0 rtol = @something rtol 0 - trunc = TruncationKeepBelow(atol, rtol) + trunc = trunctol(; atol, rtol, rev=false) return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 9f7d07193..5942b6ef9 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -28,82 +28,79 @@ end # findtruncated # ------------- +# Generic fallback +function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) + return findtruncated(values, strategy) +end + # specific implementations for finding truncated values findtruncated(values::AbstractVector, ::NoTruncation) = Colon() -function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted) +function findtruncated(values::AbstractVector, strategy::TruncationByOrder) howmany = min(strategy.howmany, length(values)) - return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev) + return partialsortperm(values, 1:howmany; strategy.by, strategy.rev) end -function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepSorted) +function findtruncated_sorted(values::AbstractVector, strategy::TruncationByOrder) howmany = min(strategy.howmany, length(values)) return 1:howmany end -# TODO: consider if worth using that values are sorted when filter is `<` or `>`. -function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered) +function findtruncated(values::AbstractVector, strategy::TruncationByFilter) ind = findall(strategy.filter, values) return ind end -function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow) - atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - return findall(≤(atol) ∘ strategy.by, values) -end -function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow) - atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - i = searchsortedfirst(values, atol; by=strategy.by, rev=true) - return i:length(values) -end - -function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove) +function findtruncated(values::AbstractVector, strategy::TruncationByValue) atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - return findall(≥(atol) ∘ strategy.by, values) + filter = (strategy.rev ? ≥(atol) : ≤(atol)) ∘ strategy.by + return findall(filter, values) end -function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove) +function findtruncated_sorted(values::AbstractVector, strategy::TruncationByValue) atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - i = searchsortedlast(values, atol; by=strategy.by, rev=true) - return 1:i -end - -function findtruncated(values::AbstractVector, strategy::TruncationIntersection) - inds = map(Base.Fix1(findtruncated, values), strategy.components) - return intersect(inds...) -end -function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection) - inds = map(Base.Fix1(findtruncated_sorted, values), strategy.components) - return intersect(inds...) + @assert strategy.by === abs || strategy.by === real "sorting strategy incompatible with implementation" + if strategy.rev + i = searchsortedlast(values, atol; by=strategy.by, rev=true) + return 1:i + else + i = searchsortedfirst(values, atol; by=strategy.by, rev=true) + return i:length(values) + end end -function findtruncated(values::AbstractVector, strategy::TruncationError) +function findtruncated(values::AbstractVector, strategy::TruncationByError) I = sortperm(values; by=abs, rev=true) - I′ = _truncerr_impl(values, I, strategy) + I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p) return I[I′] end -function findtruncated_sorted(values::AbstractVector, strategy::TruncationError) +function findtruncated_sorted(values::AbstractVector, strategy::TruncationByError) I = eachindex(values) - I′ = _truncerr_impl(values, I, strategy) + I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p) return I[I′] end -function _truncerr_impl(values::AbstractVector, I, strategy::TruncationError) - Nᵖ = sum(Base.Fix2(^, strategy.p) ∘ abs, values) - ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ) +function _truncerr_impl(values::AbstractVector, I; atol::Real=0, rtol::Real=0, p::Real=2) + by = Base.Fix2(^, p) ∘ abs + Nᵖ = sum(by, values) + ϵᵖ = max(atol^p, rtol^p * Nᵖ) + + # fast path to avoid checking all values ϵᵖ ≥ Nᵖ && return Base.OneTo(0) truncerrᵖ = zero(real(eltype(values))) rank = length(values) for i in reverse(I) - truncerrᵖ += abs(values[i])^strategy.p - if truncerrᵖ ≥ ϵᵖ - break - else - rank -= 1 - end + truncerrᵖ += by(values[i]) + truncerrᵖ ≥ ϵᵖ && break + rank -= 1 end + return Base.OneTo(rank) end -# Generic fallback -function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) - return findtruncated(values, strategy) +function findtruncated(values::AbstractVector, strategy::TruncationIntersection) + inds = map(Base.Fix1(findtruncated, values), strategy.components) + return intersect(inds...) +end +function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection) + inds = map(Base.Fix1(findtruncated_sorted, values), strategy.components) + return intersect(inds...) end diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index 6d63ca8f1..9855ae542 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -14,14 +14,14 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing) elseif isnothing(maxrank) atol = @something atol 0 rtol = @something rtol 0 - return TruncationKeepAbove(atol, rtol) + return trunctol(; atol, rtol) else if isnothing(atol) && isnothing(rtol) return truncrank(maxrank) else atol = @something atol 0 rtol = @something rtol 0 - return truncrank(maxrank) & TruncationKeepAbove(atol, rtol) + return truncrank(maxrank) & trunctol(; atol, rtol) end end end @@ -42,78 +42,94 @@ Truncation strategy that does nothing, and keeps all the values. notrunc() = NoTruncation() """ - TruncationKeepSorted(howmany::Int, by::Function, rev::Bool) + TruncationByOrder(howmany::Int, by::Function, rev::Bool) Truncation strategy to keep the first `howmany` values when sorted according to `by` in increasing (decreasing) order if `rev` is false (true). + See also [`truncrank`](@ref). """ -struct TruncationKeepSorted{F} <: TruncationStrategy +struct TruncationByOrder{F} <: TruncationStrategy howmany::Int by::F rev::Bool end """ - truncrank(howmany::Int; by=abs, rev=true) + truncrank(howmany::Integer; by=abs, rev::Bool=true) Truncation strategy to keep the first `howmany` values when sorted according to `by` or the last `howmany` if `rev` is true. """ -truncrank(howmany::Int; by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev) +truncrank(howmany::Integer; by=abs, rev::Bool=true) = TruncationByOrder(howmany, by, rev) """ - TruncationKeepFiltered(filter::Function) + TruncationByFilter(filter::Function) Truncation strategy to keep the values for which `filter` returns true. + +See also [`truncfilter`](@ref). """ -struct TruncationKeepFiltered{F} <: TruncationStrategy +struct TruncationByFilter{F} <: TruncationStrategy filter::F end """ - truncabove(val::Real; by=abs) + truncfilter(filter) + +Truncation strategy to keep the values for which `filter` returns true. +""" +truncfilter(f) = TruncationByFilter(f) -Truncation strategy to discard the values that are larger than `val` according to `by`. """ -truncabove(val::Real; by=abs) = TruncationKeepFiltered(≤(val) ∘ by) + TruncationByValue(atol::Real, rtol::Real, p::Real, by, rev::Bool=true) -struct TruncationKeepAbove{T<:Real,P<:Real,F} <: TruncationStrategy +Truncation strategy to keep the values that satisfy `by(val) < max(atol, rtol * norm(values, p)` +if `rev = true`, or discard them when `rev = false`. +See also [`trunctol`](@ref) +""" +struct TruncationByValue{T<:Real,P<:Real,F} <: TruncationStrategy atol::T rtol::T p::P by::F + rev::Bool end -function TruncationKeepAbove(; atol::Real, rtol::Real, p::Real=2, by=abs) - return TruncationKeepAbove(atol, rtol, p, by) +function TruncationByValue(atol::Real, rtol::Real, p::Real=2, by=abs, rev::Bool=true) + return TruncationByValue(promote(atol, rtol)..., p, by, rev) end -function TruncationKeepAbove(atol::Real, rtol::Real, p::Real=2, by=abs) - return TruncationKeepAbove(promote(atol, rtol)..., p, by) + +""" + trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, ) + +Truncation strategy to keep the values that satisfy `by(val) < max(atol, rtol * norm(values, p)` +if `rev = true`, or discard them when `rev = false`. +""" +function trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, rev::Bool=true) + return TruncationByValue(atol, rtol, p, by, rev) end """ - TruncationKeepBelow(; atol::Real, rtol::Real, p=2, by=abs) + TruncationByError(; atol::Real, rtol::Real, p::Real) -Truncation strategy to discard the values that are smaller than the norm of the values. +Truncation strategy to discard values until the error caused by the discarded values exceeds some tolerances. +See also [`truncerror`](@ref). """ -struct TruncationKeepBelow{T<:Real,P<:Real,F} <: TruncationStrategy +struct TruncationByError{T<:Real,P<:Real} <: TruncationStrategy atol::T rtol::T p::P - by::F -end -function TruncationKeepBelow(; atol::Real, rtol::Real, p::Real=2, by=abs) - return TruncationKeepBelow(atol, rtol, p, by) end -function TruncationKeepBelow(atol::Real, rtol::Real, p::Real=2, by=abs) - return TruncationKeepBelow(promote(atol, rtol)..., p, by) +function TruncationError(atol::Real, rtol::Real, p::Real=2) + return TruncationError(promote(atol, rtol)..., p) end """ - trunctol(; atol::Real, rtol::Real, p::Real=2, by=abs) + truncerror(; atol::Real=0, rtol::Real=0, p::Real=2) -Truncation strategy to discard all values that satisfy `by(val) < max(atol, rtol * norm(values))`. +Create a truncation strategy for truncating such that the error in the factorization +is smaller than `max(atol, rtol * norm)`, where the error is determined using the `p`-norm. """ -function trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs) - return TruncationKeepBelow(; atol, rtol, p, by) +function truncerror(; atol::Real=0, rtol::Real=0, p::Real=2) + return TruncationByError(promote(atol, rtol)..., p) end """ @@ -121,8 +137,7 @@ end Composition of multiple truncation strategies, keeping values common between them. """ -struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <: - TruncationStrategy +struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <: TruncationStrategy components::T end function TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) @@ -141,31 +156,3 @@ end function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection) return TruncationIntersection((trunc1, trunc2.components...)) end - -""" - TruncationError(; atol::Real, rtol::Real, p::Real) - -Truncation strategy to discard values until the error caused by the discarded values exceeds some tolerances. -See also [`truncerror`](@ref). -""" -struct TruncationError{T<:Real,P<:Real} <: TruncationStrategy - atol::T - rtol::T - p::P -end -function TruncationError(; atol::Real, rtol::Real, p::Real=2) - return TruncationError(atol, rtol, p) -end -function TruncationError(atol::Real, rtol::Real, p::Real=2) - return TruncationError(promote(atol, rtol)..., p) -end - -""" - truncerror(; atol::Real=0, rtol::Real=0, p::Real=2) - -Create a truncation strategy for truncating such that the error in the factorization -is smaller than `max(atol, rtol * norm)`, where the error is determined using the `p`-norm. -""" -function truncerror(; atol::Real=0, rtol::Real=0, p::Real=2) - return TruncationError(promote(atol, rtol)..., p) -end From 62f34ec57c875cb7dcd48bd8b105c78c798c4a6e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 20 Sep 2025 09:04:04 +0200 Subject: [PATCH 03/17] Bump v0.4 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8d4a6b89c..7ff516244 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.3.2" +version = "0.4.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From dddb770b097daf6bc74628925707fed336a59db5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 21 Sep 2025 09:57:44 +0200 Subject: [PATCH 04/17] update tests --- test/algorithms.jl | 12 +++---- test/chainrules.jl | 4 +-- test/eig.jl | 2 +- test/eigh.jl | 2 +- test/orthnull.jl | 9 ++--- test/svd.jl | 12 +++---- test/truncate.jl | 89 +++++++++++++++++++++------------------------- 7 files changed, 62 insertions(+), 68 deletions(-) diff --git a/test/algorithms.jl b/test/algorithms.jl index 9f9c4542c..7648384d2 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -1,8 +1,8 @@ using MatrixAlgebraKit using Test using TestExtras -using MatrixAlgebraKit: LAPACK_SVDAlgorithm, NoTruncation, PolarViaSVD, TruncatedAlgorithm, - TruncationKeepBelow, default_algorithm, select_algorithm +using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, + default_algorithm, select_algorithm @testset "default_algorithm" begin A = randn(3, 3) @@ -38,19 +38,19 @@ end A = randn(3, 3) for f in (svd_trunc!, svd_trunc) @test @constinferred(select_algorithm(f, A)) === - TruncatedAlgorithm(LAPACK_DivideAndConquer(), NoTruncation()) + TruncatedAlgorithm(LAPACK_DivideAndConquer(), notrunc()) end for f in (eig_trunc!, eig_trunc) @test @constinferred(select_algorithm(f, A)) === - TruncatedAlgorithm(LAPACK_Expert(), NoTruncation()) + TruncatedAlgorithm(LAPACK_Expert(), notrunc()) end for f in (eigh_trunc!, eigh_trunc) @test @constinferred(select_algorithm(f, A)) === TruncatedAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(), - NoTruncation()) + notrunc()) end - alg = TruncatedAlgorithm(LAPACK_Simple(), TruncationKeepBelow(0.1, 0.0)) + alg = TruncatedAlgorithm(LAPACK_Simple(), trunctol(; atol=0.1, rev=false)) for f in (eig_trunc!, eigh_trunc!, svd_trunc!) @test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg @test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc=(; maxrank=2)) diff --git a/test/chainrules.jl b/test/chainrules.jl index a48d2aaec..2073e9ee4 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -284,7 +284,7 @@ end output_tangent=(ΔU[:, 1:r], ΔS[1:r, 1:r], ΔVᴴ[1:r, :]), atol=atol, rtol=rtol) end - truncalg = TruncatedAlgorithm(alg, trunctol(S[1, 1] / 2)) + truncalg = TruncatedAlgorithm(alg, trunctol(; atol=S[1, 1] / 2)) r = findlast(>=(S[1, 1] / 2), diagview(S)) test_rrule(copy_svd_trunc, A, truncalg ⊢ NoTangent(); output_tangent=(ΔU[:, 1:r], ΔS[1:r, 1:r], ΔVᴴ[1:r, :]), @@ -302,7 +302,7 @@ end atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) end r = findlast(>=(S[1, 1] / 2), diagview(S)) - test_rrule(config, svd_trunc, A; fkwargs=(; trunc=trunctol(S[1, 1] / 2)), + test_rrule(config, svd_trunc, A; fkwargs=(; trunc=trunctol(; atol=S[1, 1] / 2)), output_tangent=(ΔU[:, 1:r], ΔS[1:r, 1:r], ΔVᴴ[1:r, :]), atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) end diff --git a/test/eig.jl b/test/eig.jl index b3e8c600a..b6951d340 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -48,7 +48,7 @@ end @test A * V1 ≈ V1 * D1 s = 1 + sqrt(eps(real(T))) - trunc = trunctol(s * abs(D₀[r + 1])) + trunc = trunctol(; atol=s * abs(D₀[r + 1])) D2, V2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 diff --git a/test/eigh.jl b/test/eigh.jl index 9b8811607..62a5d7ca1 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -52,7 +52,7 @@ end @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - trunc = trunctol(s * D₀[r + 1]) + trunc = trunctol(; atol=s * D₀[r + 1]) D2, V2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometry(V2) diff --git a/test/orthnull.jl b/test/orthnull.jl index c402e8303..5a1cdf059 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -3,7 +3,6 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, I, mul! -using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, initialize_output, AbstractAlgorithm @@ -33,10 +32,12 @@ end function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap) return LinearMap.(initialize_output(right_orth!, parent(A))) end -function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC, alg::AbstractAlgorithm) +function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC, + alg::AbstractAlgorithm) return check_input(left_orth!, parent(A), parent.(VC), alg) end -function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC, alg::AbstractAlgorithm) +function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC, + alg::AbstractAlgorithm) return check_input(right_orth!, parent(A), parent.(VC), alg) end function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A} @@ -124,7 +125,7 @@ end rtol = eps(real(T)) for (trunc_orth, trunc_null) in (((; rtol=rtol), (; rtol=rtol)), - (TruncationKeepAbove(0, rtol), TruncationKeepBelow(0, rtol))) + (trunctol(; rtol), trunctol(; rtol, rev=false))) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth) N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=trunc_null) @test V2 !== V diff --git a/test/svd.jl b/test/svd.jl index f3a1a4cba..3f1f5fd8d 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef, norm -using MatrixAlgebraKit: TruncatedAlgorithm, TruncationKeepAbove, diagview, isisometry +using MatrixAlgebraKit: TruncatedAlgorithm, diagview, isisometry const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) @@ -113,15 +113,15 @@ end @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] s = 1 + sqrt(eps(real(T))) - trunc2 = trunctol(s * S₀[r + 1]) + trunc = trunctol(; atol=s * S₀[r + 1]) - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(s * S₀[r + 1])) + U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc) @test length(S2.diag) == r @test U1 ≈ U2 @test S1 ≈ S2 @test V1ᴴ ≈ V2ᴴ - trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) + trunc = truncerror(; atol=s * norm(@view(S₀[(r + 1):end]))) U3, S3, V3ᴴ = @constinferred svd_trunc(A; alg, trunc) @test length(S3.diag) == r @test U1 ≈ U3 @@ -147,7 +147,7 @@ end A = U * S * Vᴴ for trunc_fun in ((rtol, maxrank) -> (; rtol, maxrank), - (rtol, maxrank) -> truncrank(maxrank) & TruncationKeepAbove(0, rtol)) + (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol)) U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=trunc_fun(0.2, 1)) @test length(S1.diag) == 1 @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) @@ -166,7 +166,7 @@ end S = Diagonal([0.9, 0.3, 0.1, 0.01]) Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ - alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), TruncationKeepAbove(0.2, 0.0)) + alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), trunctol(; atol=0.2)) U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg) @test diagview(S2) ≈ diagview(S)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError svd_trunc(A; alg, trunc=(; maxrank=2)) diff --git a/test/truncate.jl b/test/truncate.jl index a2701c2a0..266515f58 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -1,32 +1,35 @@ using MatrixAlgebraKit using Test using TestExtras -using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove, - TruncationKeepBelow, TruncationStrategy, findtruncated, +using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, + TruncationByValue, TruncationStrategy, findtruncated, findtruncated_sorted @testset "truncate" begin trunc = @constinferred TruncationStrategy() @test trunc isa NoTruncation - trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3) - @test trunc isa TruncationKeepAbove - @test trunc == TruncationKeepAbove(1e-2, 1e-3) - @test trunc.atol == 1e-2 - @test trunc.rtol == 1e-3 + atol = 1e-2 + rtol = 1e-3 + maxrank = 10 - trunc = @constinferred TruncationStrategy(; maxrank=10) - @test trunc isa TruncationKeepSorted - @test trunc == truncrank(10) - @test trunc.howmany == 10 + trunc = @constinferred TruncationStrategy(; atol, rtol) + @test trunc isa TruncationByValue + @test trunc == trunctol(; atol, rtol) + @test trunc.atol == atol + @test trunc.rtol == rtol + @test trunc.rev + + trunc = @constinferred TruncationStrategy(; maxrank) + @test trunc isa TruncationByOrder + @test trunc == truncrank(maxrank) + @test trunc.howmany == maxrank @test trunc.by == abs - @test trunc.rev == true + @test trunc.rev - trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10) + trunc = @constinferred TruncationStrategy(; atol, rtol, maxrank) @test trunc isa TruncationIntersection - @test trunc == truncrank(10) & TruncationKeepAbove(1e-2, 1e-3) - @test trunc.components[1] == truncrank(10) - @test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3) + @test trunc == truncrank(maxrank) & trunctol(; atol, rtol) values = [1, 0.9, 0.5, -0.3, 0.01] @test @constinferred(findtruncated(values, truncrank(2))) == 1:2 @@ -35,42 +38,32 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov @test @constinferred(findtruncated_sorted(values, truncrank(2))) === 1:2 values = [1, 0.9, 0.5, -0.3, 0.01] - for strategy in (TruncationKeepAbove(; atol=0.4, rtol=0), - TruncationKeepAbove(0.4, 0)) - @test @constinferred(findtruncated(values, strategy)) == 1:3 - @test @constinferred(findtruncated_sorted(values, strategy)) === 1:3 - end - for strategy in (TruncationKeepBelow(; atol=0.4, rtol=0), - TruncationKeepBelow(0.4, 0)) - @test @constinferred(findtruncated(values, strategy)) == 4:5 - @test @constinferred(findtruncated_sorted(values, strategy)) === 4:5 - end + strategy = trunctol(; atol=0.4) + @test @constinferred(findtruncated(values, strategy)) == 1:3 + @test @constinferred(findtruncated_sorted(values, strategy)) === 1:3 + strategy = trunctol(; atol=0.4, rev=false) + @test @constinferred(findtruncated(values, strategy)) == 4:5 + @test @constinferred(findtruncated_sorted(values, strategy)) === 4:5 values = [0.01, 1, 0.9, -0.3, 0.5] - for strategy in (TruncationKeepAbove(; atol=0.4, rtol=0), - TruncationKeepAbove(; atol=0.4, rtol=0, by=abs), - TruncationKeepAbove(0.4, 0), - TruncationKeepAbove(; atol=0.2, rtol=0.0, by=identity)) + for strategy in (trunctol(; atol=0.4), trunctol(; atol=0.2, by=identity)) @test @constinferred(findtruncated(values, strategy)) == [2, 3, 5] end - for strategy in (TruncationKeepAbove(; atol=0.2, rtol=0), - TruncationKeepAbove(; atol=0.2, rtol=0, by=abs), - TruncationKeepAbove(0.2, 0)) - @test @constinferred(findtruncated(values, strategy)) == [2, 3, 4, 5] - end - for strategy in (TruncationKeepBelow(; atol=0.4, rtol=0), - TruncationKeepBelow(; atol=0.4, rtol=0, by=abs), - TruncationKeepBelow(0.4, 0), - TruncationKeepBelow(; atol=0.2, rtol=0.0, by=identity)) + strategy = trunctol(; atol=0.2) + @test @constinferred(findtruncated(values, strategy)) == [2, 3, 4, 5] + + for strategy in + (trunctol(; atol=0.4, rev=false), trunctol(; atol=0.2, by=identity, rev=false)) @test @constinferred(findtruncated(values, strategy)) == [1, 4] end - for strategy in (TruncationKeepBelow(; atol=0.2, rtol=0), - TruncationKeepBelow(; atol=0.2, rtol=0, by=abs), - TruncationKeepBelow(0.2, 0)) - @test @constinferred(findtruncated(values, strategy)) == [1] - end - for strategy in (truncerror(; atol=0.2, rtol=0),) - @test issetequal(@constinferred(findtruncated(values, strategy)), 2:5) - @test @constinferred(findtruncated_sorted(sort(values; by=abs, rev=true), strategy)) == 1:4 - end + strategy = trunctol(; atol=0.2, rev=false) + @test @constinferred(findtruncated(values, strategy)) == [1] + + strategy = truncfilter(x -> 0.1 < x < 1) + @test @constinferred(findtruncated(values, strategy)) == [3, 5] + + strategy = truncerror(; atol=0.2, rtol=0) + @test issetequal(@constinferred(findtruncated(values, strategy)), 2:5) + vals_sorted = sort(values; by=abs, rev=true) + @test @constinferred(findtruncated_sorted(vals_sorted, strategy)) == 1:4 end From f7db753563e337599edaea798a3427aaccb62092 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 21 Sep 2025 11:16:22 +0200 Subject: [PATCH 05/17] update docs --- docs/src/user_interface/truncations.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/user_interface/truncations.md b/docs/src/user_interface/truncations.md index b108470cb..407998908 100644 --- a/docs/src/user_interface/truncations.md +++ b/docs/src/user_interface/truncations.md @@ -11,7 +11,7 @@ Currently, truncations are supported through the following different methods: notrunc truncrank trunctol -truncabove +truncfilter truncerror ``` @@ -20,6 +20,6 @@ For example, truncating to a maximal dimension `10`, and discarding all values b ```julia maxdim = 10 -tol = 1e-6 -combined_trunc = truncrank(maxdim) & trunctol(tol) +atol = 1e-6 +combined_trunc = truncrank(maxdim) & trunctol(; atol) ``` From d72ce5ac757cb28869c8322b63e4ef7bad9e80ae Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 21 Sep 2025 11:54:49 +0200 Subject: [PATCH 06/17] update GPU tests --- test/amd/eigh.jl | 2 +- test/amd/svd.jl | 4 ++-- test/cuda/eig.jl | 2 +- test/cuda/eigh.jl | 2 +- test/cuda/svd.jl | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/amd/eigh.jl b/test/amd/eigh.jl index 44be84952..ed37b20f2 100644 --- a/test/amd/eigh.jl +++ b/test/amd/eigh.jl @@ -51,7 +51,7 @@ end @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - trunc = trunctol(s * D₀[r + 1]) + trunc = trunctol(; atol=s * D₀[r + 1]) D2, V2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometry(V2) diff --git a/test/amd/svd.jl b/test/amd/svd.jl index ecf79b003..e8acb8471 100644 --- a/test/amd/svd.jl +++ b/test/amd/svd.jl @@ -108,9 +108,9 @@ end # @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] # s = 1 + sqrt(eps(real(T))) -# trunc2 = trunctol(s * S₀[r + 1]) +# trunc2 = trunctol(; atol=s * S₀[r + 1]) -# U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(s * S₀[r + 1])) +# U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) # @test length(S2.diag) == r # @test U1 ≈ U2 # @test S1 ≈ S2 diff --git a/test/cuda/eig.jl b/test/cuda/eig.jl index 2ec8da223..e83cb3fd7 100644 --- a/test/cuda/eig.jl +++ b/test/cuda/eig.jl @@ -49,7 +49,7 @@ end @test A * V1 ≈ V1 * D1 s = 1 + sqrt(eps(real(T))) - trunc = trunctol(s * abs(D₀[r + 1])) + trunc = trunctol(; atol=s * abs(D₀[r + 1])) D2, V2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 diff --git a/test/cuda/eigh.jl b/test/cuda/eigh.jl index c15bbb12e..a439865c3 100644 --- a/test/cuda/eigh.jl +++ b/test/cuda/eigh.jl @@ -49,7 +49,7 @@ end @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - trunc = trunctol(s * D₀[r + 1]) + trunc = trunctol(; atol = s * D₀[r + 1]) D2, V2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometry(V2) diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index 8765ff08d..1b59b21ff 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -109,9 +109,9 @@ end if !(alg isa CUSOLVER_Randomized) s = 1 + sqrt(eps(real(T))) - trunc2 = trunctol(s * S₀[r + 1]) + trunc2 = trunctol(; atol=s * S₀[r + 1]) - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(s * S₀[r + 1])) + U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) @test length(S2.diag) == r @test U1 ≈ U2 @test parent(S1) ≈ parent(S2) From 24c3759746e8ef6fc510f6694be7ba8b1a6c84ff Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 22 Sep 2025 14:53:15 +0200 Subject: [PATCH 07/17] Consistency in docstrings --- src/interface/truncation.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index 9855ae542..113ca30e5 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -125,7 +125,7 @@ end """ truncerror(; atol::Real=0, rtol::Real=0, p::Real=2) -Create a truncation strategy for truncating such that the error in the factorization +Truncation strategy for truncating values such that the error in the factorization is smaller than `max(atol, rtol * norm)`, where the error is determined using the `p`-norm. """ function truncerror(; atol::Real=0, rtol::Real=0, p::Real=2) @@ -135,7 +135,8 @@ end """ TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) -Composition of multiple truncation strategies, keeping values common between them. +Truncation strategy that composes multiple truncation strategies, keeping values that are +common between them. """ struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <: TruncationStrategy components::T From cb0b0ed87dfd5a5e2d250b885b19b554ee7561ff Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 22 Sep 2025 14:57:27 +0200 Subject: [PATCH 08/17] fix TruncationByOrder sorted and rev strategy --- src/implementations/truncation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 5942b6ef9..c088690a1 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -42,7 +42,7 @@ function findtruncated(values::AbstractVector, strategy::TruncationByOrder) end function findtruncated_sorted(values::AbstractVector, strategy::TruncationByOrder) howmany = min(strategy.howmany, length(values)) - return 1:howmany + return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values)) end function findtruncated(values::AbstractVector, strategy::TruncationByFilter) From b8e6bb71b1eccfd4eb7d62c03807cdf637f1df89 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 22 Sep 2025 15:02:36 +0200 Subject: [PATCH 09/17] change meaning of rev --- src/implementations/orthnull.jl | 2 +- src/implementations/truncation.jl | 8 ++++---- src/interface/truncation.jl | 14 +++++++------- test/algorithms.jl | 2 +- test/orthnull.jl | 2 +- test/truncate.jl | 8 ++++---- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 693a3ab9c..cde2a9662 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -207,7 +207,7 @@ function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothi end atol = @something atol 0 rtol = @something rtol 0 - trunc = trunctol(; atol, rtol, rev=false) + trunc = trunctol(; atol, rtol, rev=true) return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index c088690a1..a26128927 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -52,18 +52,18 @@ end function findtruncated(values::AbstractVector, strategy::TruncationByValue) atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - filter = (strategy.rev ? ≥(atol) : ≤(atol)) ∘ strategy.by + filter = (strategy.rev ? ≤(atol) : ≥(atol)) ∘ strategy.by return findall(filter, values) end function findtruncated_sorted(values::AbstractVector, strategy::TruncationByValue) atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) @assert strategy.by === abs || strategy.by === real "sorting strategy incompatible with implementation" if strategy.rev - i = searchsortedlast(values, atol; by=strategy.by, rev=true) - return 1:i - else i = searchsortedfirst(values, atol; by=strategy.by, rev=true) return i:length(values) + else + i = searchsortedlast(values, atol; by=strategy.by, rev=true) + return 1:i end end diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index 113ca30e5..2c35316f4 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -80,10 +80,10 @@ Truncation strategy to keep the values for which `filter` returns true. truncfilter(f) = TruncationByFilter(f) """ - TruncationByValue(atol::Real, rtol::Real, p::Real, by, rev::Bool=true) + TruncationByValue(atol::Real, rtol::Real, p::Real, by, rev::Bool=false) -Truncation strategy to keep the values that satisfy `by(val) < max(atol, rtol * norm(values, p)` -if `rev = true`, or discard them when `rev = false`. +Truncation strategy to keep the values that satisfy `by(val) > max(atol, rtol * norm(values, p)` +if `rev = false`, or discard them when `rev = true`. See also [`trunctol`](@ref) """ struct TruncationByValue{T<:Real,P<:Real,F} <: TruncationStrategy @@ -98,12 +98,12 @@ function TruncationByValue(atol::Real, rtol::Real, p::Real=2, by=abs, rev::Bool= end """ - trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, ) + trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, rev::Bool=false) -Truncation strategy to keep the values that satisfy `by(val) < max(atol, rtol * norm(values, p)` -if `rev = true`, or discard them when `rev = false`. +Truncation strategy to keep the values that satisfy `by(val) > max(atol, rtol * norm(values, p)` +if `rev = false`, or discard them when `rev = true`. """ -function trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, rev::Bool=true) +function trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, rev::Bool=false) return TruncationByValue(atol, rtol, p, by, rev) end diff --git a/test/algorithms.jl b/test/algorithms.jl index 7648384d2..fdb4c9e2d 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -50,7 +50,7 @@ end notrunc()) end - alg = TruncatedAlgorithm(LAPACK_Simple(), trunctol(; atol=0.1, rev=false)) + alg = TruncatedAlgorithm(LAPACK_Simple(), trunctol(; atol=0.1, rev=true)) for f in (eig_trunc!, eigh_trunc!, svd_trunc!) @test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg @test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc=(; maxrank=2)) diff --git a/test/orthnull.jl b/test/orthnull.jl index 5a1cdf059..04d8c94c4 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -125,7 +125,7 @@ end rtol = eps(real(T)) for (trunc_orth, trunc_null) in (((; rtol=rtol), (; rtol=rtol)), - (trunctol(; rtol), trunctol(; rtol, rev=false))) + (trunctol(; rtol), trunctol(; rtol, rev=true))) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth) N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=trunc_null) @test V2 !== V diff --git a/test/truncate.jl b/test/truncate.jl index 266515f58..0d3391b73 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -18,7 +18,7 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, @test trunc == trunctol(; atol, rtol) @test trunc.atol == atol @test trunc.rtol == rtol - @test trunc.rev + @test !trunc.rev trunc = @constinferred TruncationStrategy(; maxrank) @test trunc isa TruncationByOrder @@ -41,7 +41,7 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, strategy = trunctol(; atol=0.4) @test @constinferred(findtruncated(values, strategy)) == 1:3 @test @constinferred(findtruncated_sorted(values, strategy)) === 1:3 - strategy = trunctol(; atol=0.4, rev=false) + strategy = trunctol(; atol=0.4, rev=true) @test @constinferred(findtruncated(values, strategy)) == 4:5 @test @constinferred(findtruncated_sorted(values, strategy)) === 4:5 @@ -53,10 +53,10 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, @test @constinferred(findtruncated(values, strategy)) == [2, 3, 4, 5] for strategy in - (trunctol(; atol=0.4, rev=false), trunctol(; atol=0.2, by=identity, rev=false)) + (trunctol(; atol=0.4, rev=true), trunctol(; atol=0.2, by=identity, rev=true)) @test @constinferred(findtruncated(values, strategy)) == [1, 4] end - strategy = trunctol(; atol=0.2, rev=false) + strategy = trunctol(; atol=0.2, rev=true) @test @constinferred(findtruncated(values, strategy)) == [1] strategy = truncfilter(x -> 0.1 < x < 1) From 6f994a4b42f3feb52bed0690d5b71cabb55e3f85 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 23 Sep 2025 08:12:53 -0400 Subject: [PATCH 10/17] Attempt to fix GPU implementation --- ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl | 7 ++++++- ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 6 +++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 7f51bde93..8efe97fa1 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -4,7 +4,7 @@ using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe -using MatrixAlgebraKit: LQViaTransposedQR +using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx! @@ -40,4 +40,9 @@ _gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwar _gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevd!(A, Dd, V; kwargs...) _gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heev!(A, Dd, V; kwargs...) _gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevx!(A, Dd, V; kwargs...) + +function MatrixAlgebraKit.findtruncated_sorted(values::StridedROCVector, strategy::TruncationByValue) + return MatrixAlgebraKit.findtruncated(values, strategy) +end + end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 1fafb269d..fa7488302 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -4,7 +4,7 @@ using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe -using MatrixAlgebraKit: LQViaTransposedQR +using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd! @@ -44,4 +44,8 @@ _gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::S _gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...) _gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...) +function MatrixAlgebraKit.findtruncated_sorted(values::StridedCuVector, strategy::TruncationByValue) + return MatrixAlgebraKit.findtruncated(values, strategy) +end + end From 852774a67cb45e558862c8482f1c28b4cd130375 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 23 Sep 2025 20:25:48 -0400 Subject: [PATCH 11/17] fixes and `keep_below` --- src/implementations/orthnull.jl | 4 +-- src/implementations/truncation.jl | 49 +++++++++++++++++++++---------- src/interface/truncation.jl | 27 +++++++++-------- test/algorithms.jl | 2 +- test/orthnull.jl | 2 +- test/truncate.jl | 26 ++++++++-------- 6 files changed, 65 insertions(+), 45 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index cde2a9662..df56302d6 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -203,11 +203,11 @@ end # -------------------------------- function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothing) if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol) - return NoTruncation() + return notrunc() end atol = @something atol 0 rtol = @something rtol 0 - trunc = trunctol(; atol, rtol, rev=true) + trunc = trunctol(; atol, rtol, keep_below=true) return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index a26128927..7f7a8183f 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -40,29 +40,36 @@ function findtruncated(values::AbstractVector, strategy::TruncationByOrder) howmany = min(strategy.howmany, length(values)) return partialsortperm(values, 1:howmany; strategy.by, strategy.rev) end -function findtruncated_sorted(values::AbstractVector, strategy::TruncationByOrder) - howmany = min(strategy.howmany, length(values)) - return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values)) +function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder) + if strategy.by === abs + howmany = min(strategy.howmany, length(values)) + return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values)) + else + return findtruncated(values, strategy) + end end function findtruncated(values::AbstractVector, strategy::TruncationByFilter) - ind = findall(strategy.filter, values) - return ind + # pre-allocate bitvector to enforce the filter function returns a Bool + mask = similar(BitArray, eachindex(values)) + mask .= strategy.filter.(values) + return mask end function findtruncated(values::AbstractVector, strategy::TruncationByValue) atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - filter = (strategy.rev ? ≤(atol) : ≥(atol)) ∘ strategy.by - return findall(filter, values) + filter = (strategy.keep_below ? ≤(atol) : ≥(atol)) ∘ strategy.by + return findtruncated(values, truncfilter(filter)) end -function findtruncated_sorted(values::AbstractVector, strategy::TruncationByValue) +function findtruncated_svd(values::AbstractVector, strategy::TruncationByValue) + strategy.by === abs || return findtruncated(values, strategy) + atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - @assert strategy.by === abs || strategy.by === real "sorting strategy incompatible with implementation" - if strategy.rev - i = searchsortedfirst(values, atol; by=strategy.by, rev=true) + if strategy.keep_below + i = searchsortedfirst(values, atol; by=abs, rev=true) return i:length(values) else - i = searchsortedlast(values, atol; by=strategy.by, rev=true) + i = searchsortedlast(values, atol; by=abs, rev=true) return 1:i end end @@ -97,10 +104,20 @@ function _truncerr_impl(values::AbstractVector, I; atol::Real=0, rtol::Real=0, p end function findtruncated(values::AbstractVector, strategy::TruncationIntersection) - inds = map(Base.Fix1(findtruncated, values), strategy.components) - return intersect(inds...) + return mapreduce(Base.Fix1(findtruncated, values), _ind_intersect, strategy.components; + init=trues(length(values))) end function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection) - inds = map(Base.Fix1(findtruncated_sorted, values), strategy.components) - return intersect(inds...) + return mapreduce(Base.Fix1(findtruncated_sorted, values), _ind_intersect, + strategy.components; init=trues(length(values))) end + +# when one of the ind selections is a bitvector, have to handle differently +function _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector) + result = falses(length(A)) + result[B] .= @view A[B] + return result +end +_ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A) +_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B +_ind_intersect(A, B) = intersect(A, B) diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index 2c35316f4..e06536d77 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -41,6 +41,7 @@ Truncation strategy that does nothing, and keeps all the values. """ notrunc() = NoTruncation() +# TODO: Base.Ordering? """ TruncationByOrder(howmany::Int, by::Function, rev::Bool) @@ -59,7 +60,9 @@ end Truncation strategy to keep the first `howmany` values when sorted according to `by` or the last `howmany` if `rev` is true. """ -truncrank(howmany::Integer; by=abs, rev::Bool=true) = TruncationByOrder(howmany, by, rev) +function truncrank(howmany::Integer; by=abs, rev::Bool=true) + return TruncationByOrder(howmany, by, rev) +end """ TruncationByFilter(filter::Function) @@ -80,10 +83,10 @@ Truncation strategy to keep the values for which `filter` returns true. truncfilter(f) = TruncationByFilter(f) """ - TruncationByValue(atol::Real, rtol::Real, p::Real, by, rev::Bool=false) + TruncationByValue(atol::Real, rtol::Real, p::Real, by, keep_below::Bool=false) -Truncation strategy to keep the values that satisfy `by(val) > max(atol, rtol * norm(values, p)` -if `rev = false`, or discard them when `rev = true`. +Truncation strategy to keep the values that satisfy `by(val) > max(atol, rtol * norm(values, p)`. +If `keep_below = true`, discard these values instead. See also [`trunctol`](@ref) """ struct TruncationByValue{T<:Real,P<:Real,F} <: TruncationStrategy @@ -91,20 +94,20 @@ struct TruncationByValue{T<:Real,P<:Real,F} <: TruncationStrategy rtol::T p::P by::F - rev::Bool + keep_below::Bool end -function TruncationByValue(atol::Real, rtol::Real, p::Real=2, by=abs, rev::Bool=true) - return TruncationByValue(promote(atol, rtol)..., p, by, rev) +function TruncationByValue(atol::Real, rtol::Real, p::Real=2, by=abs, keep_below::Bool=true) + return TruncationByValue(promote(atol, rtol)..., p, by, keep_below) end """ - trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, rev::Bool=false) + trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, keep_below::Bool=false) -Truncation strategy to keep the values that satisfy `by(val) > max(atol, rtol * norm(values, p)` -if `rev = false`, or discard them when `rev = true`. +Truncation strategy to keep the values that satisfy `by(val) > max(atol, rtol * norm(values, p)`. +If `keep_below = true`, discard these values instead. """ -function trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, rev::Bool=false) - return TruncationByValue(atol, rtol, p, by, rev) +function trunctol(; atol::Real=0, rtol::Real=0, p::Real=2, by=abs, keep_below::Bool=false) + return TruncationByValue(atol, rtol, p, by, keep_below) end """ diff --git a/test/algorithms.jl b/test/algorithms.jl index fdb4c9e2d..7e98e1f13 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -50,7 +50,7 @@ end notrunc()) end - alg = TruncatedAlgorithm(LAPACK_Simple(), trunctol(; atol=0.1, rev=true)) + alg = TruncatedAlgorithm(LAPACK_Simple(), trunctol(; atol=0.1, keep_below=true)) for f in (eig_trunc!, eigh_trunc!, svd_trunc!) @test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg @test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc=(; maxrank=2)) diff --git a/test/orthnull.jl b/test/orthnull.jl index 04d8c94c4..d7cb4ff57 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -125,7 +125,7 @@ end rtol = eps(real(T)) for (trunc_orth, trunc_null) in (((; rtol=rtol), (; rtol=rtol)), - (trunctol(; rtol), trunctol(; rtol, rev=true))) + (trunctol(; rtol), trunctol(; rtol, keep_below=true))) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth) N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=trunc_null) @test V2 !== V diff --git a/test/truncate.jl b/test/truncate.jl index 0d3391b73..663597477 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -18,7 +18,7 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, @test trunc == trunctol(; atol, rtol) @test trunc.atol == atol @test trunc.rtol == rtol - @test !trunc.rev + @test !trunc.keep_below trunc = @constinferred TruncationStrategy(; maxrank) @test trunc isa TruncationByOrder @@ -39,28 +39,28 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, values = [1, 0.9, 0.5, -0.3, 0.01] strategy = trunctol(; atol=0.4) - @test @constinferred(findtruncated(values, strategy)) == 1:3 - @test @constinferred(findtruncated_sorted(values, strategy)) === 1:3 - strategy = trunctol(; atol=0.4, rev=true) - @test @constinferred(findtruncated(values, strategy)) == 4:5 - @test @constinferred(findtruncated_sorted(values, strategy)) === 4:5 + @test findall(@constinferred(findtruncated(values, strategy))) == 1:3 + @test @constinferred(findtruncated_svd(values, strategy)) === 1:3 + strategy = trunctol(; atol=0.4, keep_below=true) + @test findall(@constinferred(findtruncated(values, strategy))) == 4:5 + @test @constinferred(findtruncated_svd(values, strategy)) === 4:5 values = [0.01, 1, 0.9, -0.3, 0.5] for strategy in (trunctol(; atol=0.4), trunctol(; atol=0.2, by=identity)) - @test @constinferred(findtruncated(values, strategy)) == [2, 3, 5] + @test findall(@constinferred(findtruncated(values, strategy))) == [2, 3, 5] end strategy = trunctol(; atol=0.2) - @test @constinferred(findtruncated(values, strategy)) == [2, 3, 4, 5] + @test findall(@constinferred(findtruncated(values, strategy))) == [2, 3, 4, 5] for strategy in - (trunctol(; atol=0.4, rev=true), trunctol(; atol=0.2, by=identity, rev=true)) - @test @constinferred(findtruncated(values, strategy)) == [1, 4] + (trunctol(; atol=0.4, keep_below=true), trunctol(; atol=0.2, by=identity, keep_below=true)) + @test findall(@constinferred(findtruncated(values, strategy))) == [1, 4] end - strategy = trunctol(; atol=0.2, rev=true) - @test @constinferred(findtruncated(values, strategy)) == [1] + strategy = trunctol(; atol=0.2, keep_below=true) + @test findall(@constinferred(findtruncated(values, strategy))) == [1] strategy = truncfilter(x -> 0.1 < x < 1) - @test @constinferred(findtruncated(values, strategy)) == [3, 5] + @test findall(@constinferred(findtruncated(values, strategy))) == [3, 5] strategy = truncerror(; atol=0.2, rtol=0) @test issetequal(@constinferred(findtruncated(values, strategy)), 2:5) From b5d60ed355f1fc1ed693b381675bf522d8fd1f6a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 23 Sep 2025 21:55:34 -0400 Subject: [PATCH 12/17] rename `findtruncated_sorted` to `findtruncated_svd` --- docs/src/dev_interface.md | 2 +- .../MatrixAlgebraKitAMDGPUExt.jl | 2 +- ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 2 +- src/MatrixAlgebraKit.jl | 2 +- src/algorithms.jl | 6 +++--- src/implementations/truncation.jl | 10 +++++----- test/truncate.jl | 6 +++--- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/src/dev_interface.md b/docs/src/dev_interface.md index 52d44ca13..bb722c498 100644 --- a/docs/src/dev_interface.md +++ b/docs/src/dev_interface.md @@ -11,5 +11,5 @@ MatrixAlgebraKit.jl provides a developer interface for specifying custom algorit MatrixAlgebraKit.default_algorithm MatrixAlgebraKit.select_algorithm MatrixAlgebraKit.findtruncated -MatrixAlgebraKit.findtruncated_sorted +MatrixAlgebraKit.findtruncated_svd ``` diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 8efe97fa1..8b14f5700 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -41,7 +41,7 @@ _gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwar _gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heev!(A, Dd, V; kwargs...) _gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevx!(A, Dd, V; kwargs...) -function MatrixAlgebraKit.findtruncated_sorted(values::StridedROCVector, strategy::TruncationByValue) +function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::TruncationByValue) return MatrixAlgebraKit.findtruncated(values, strategy) end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index fa7488302..20ea9496a 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -44,7 +44,7 @@ _gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::S _gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...) _gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...) -function MatrixAlgebraKit.findtruncated_sorted(values::StridedCuVector, strategy::TruncationByValue) +function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::TruncationByValue) return MatrixAlgebraKit.findtruncated(values, strategy) end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index a88c1747e..da7c2af99 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -38,7 +38,7 @@ export ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, export notrunc, truncrank, trunctol, truncerror, truncfilter VERSION >= v"1.11.0-DEV.469" && - eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_sorted, + eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_svd, :select_algorithm)) include("common/defaults.jl") diff --git a/src/algorithms.jl b/src/algorithms.jl index abce4097f..81d2d1ac3 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -168,16 +168,16 @@ based on the `strategy`. The output should be a collection of indices specifying which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the values are sorted. For a version that assumes the values are reverse sorted (which is the -standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref). +standard case for SVD) see [`MatrixAlgebraKit.findtruncated_svd`](@ref). """ findtruncated @doc """ - MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) + MatrixAlgebraKit.findtruncated_svd(values::AbstractVector, strategy::TruncationStrategy) Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are real and sorted in descending order, as typically obtained by the SVD. This assumption is not checked, and this is used in the default implementation of [`svd_trunc!`](@ref). -""" findtruncated_sorted +""" findtruncated_svd """ TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 7f7a8183f..aa9d720dd 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -2,7 +2,7 @@ # --------- # Generic implementation: `findtruncated` followed by indexing function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy) - ind = findtruncated_sorted(diagview(S), strategy) + ind = findtruncated_svd(diagview(S), strategy) return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :] end function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy) @@ -29,7 +29,7 @@ end # findtruncated # ------------- # Generic fallback -function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) +function findtruncated_svd(values::AbstractVector, strategy::TruncationStrategy) return findtruncated(values, strategy) end @@ -79,7 +79,7 @@ function findtruncated(values::AbstractVector, strategy::TruncationByError) I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p) return I[I′] end -function findtruncated_sorted(values::AbstractVector, strategy::TruncationByError) +function findtruncated_svd(values::AbstractVector, strategy::TruncationByError) I = eachindex(values) I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p) return I[I′] @@ -107,8 +107,8 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection) return mapreduce(Base.Fix1(findtruncated, values), _ind_intersect, strategy.components; init=trues(length(values))) end -function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection) - return mapreduce(Base.Fix1(findtruncated_sorted, values), _ind_intersect, +function findtruncated_svd(values::AbstractVector, strategy::TruncationIntersection) + return mapreduce(Base.Fix1(findtruncated_svd, values), _ind_intersect, strategy.components; init=trues(length(values))) end diff --git a/test/truncate.jl b/test/truncate.jl index 663597477..2c13c2eb6 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -3,7 +3,7 @@ using Test using TestExtras using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, TruncationByValue, TruncationStrategy, findtruncated, - findtruncated_sorted + findtruncated_svd @testset "truncate" begin trunc = @constinferred TruncationStrategy() @@ -35,7 +35,7 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, @test @constinferred(findtruncated(values, truncrank(2))) == 1:2 @test @constinferred(findtruncated(values, truncrank(2; rev=false))) == [5, 4] @test @constinferred(findtruncated(values, truncrank(2; by=((-) ∘ abs)))) == [5, 4] - @test @constinferred(findtruncated_sorted(values, truncrank(2))) === 1:2 + @test @constinferred(findtruncated_svd(values, truncrank(2))) === 1:2 values = [1, 0.9, 0.5, -0.3, 0.01] strategy = trunctol(; atol=0.4) @@ -65,5 +65,5 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, strategy = truncerror(; atol=0.2, rtol=0) @test issetequal(@constinferred(findtruncated(values, strategy)), 2:5) vals_sorted = sort(values; by=abs, rev=true) - @test @constinferred(findtruncated_sorted(vals_sorted, strategy)) == 1:4 + @test @constinferred(findtruncated_svd(vals_sorted, strategy)) == 1:4 end From 5e5a532813fe51e828773e4ca0f550a969bff2c5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 23 Sep 2025 21:59:26 -0400 Subject: [PATCH 13/17] make truncation strategies public to allow specializing --- src/MatrixAlgebraKit.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index da7c2af99..d29e16f28 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -37,9 +37,12 @@ export ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, export notrunc, truncrank, trunctol, truncerror, truncfilter -VERSION >= v"1.11.0-DEV.469" && +@static if VERSION >= v"1.11.0-DEV.469" eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_svd, :select_algorithm)) + eval(Expr(:public, :TruncationByOrder, :TruncationByFilter, :TruncationByValue, + :TruncationByError, :TruncationIntersection)) +end include("common/defaults.jl") include("common/initialization.jl") From fdb36574ac7a9584b74403efaed171f418b5614b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 06:16:21 -0400 Subject: [PATCH 14/17] make bitarray more GPU friendly --- src/implementations/truncation.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index aa9d720dd..2371612bf 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -50,10 +50,7 @@ function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder) end function findtruncated(values::AbstractVector, strategy::TruncationByFilter) - # pre-allocate bitvector to enforce the filter function returns a Bool - mask = similar(BitArray, eachindex(values)) - mask .= strategy.filter.(values) - return mask + return strategy.filter.(values)::AbstractVector{Bool} end function findtruncated(values::AbstractVector, strategy::TruncationByValue) From 5d720a38e96f20cc90a09007c68da668c90ffde8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 10:14:21 -0400 Subject: [PATCH 15/17] small formatting changes --- src/implementations/truncation.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 2371612bf..d6f57aaac 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -41,12 +41,9 @@ function findtruncated(values::AbstractVector, strategy::TruncationByOrder) return partialsortperm(values, 1:howmany; strategy.by, strategy.rev) end function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder) - if strategy.by === abs - howmany = min(strategy.howmany, length(values)) - return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values)) - else - return findtruncated(values, strategy) - end + strategy.by === abs || return findtruncated(values, strategy) + howmany = min(strategy.howmany, length(values)) + return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values)) end function findtruncated(values::AbstractVector, strategy::TruncationByFilter) @@ -60,7 +57,6 @@ function findtruncated(values::AbstractVector, strategy::TruncationByValue) end function findtruncated_svd(values::AbstractVector, strategy::TruncationByValue) strategy.by === abs || return findtruncated(values, strategy) - atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) if strategy.keep_below i = searchsortedfirst(values, atol; by=abs, rev=true) From 719948665bf5a639507c30bcfac7727b375dd41a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 25 Sep 2025 13:16:47 -0400 Subject: [PATCH 16/17] change back to findall --- src/implementations/truncation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index d6f57aaac..00fcabb96 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -47,7 +47,7 @@ function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder) end function findtruncated(values::AbstractVector, strategy::TruncationByFilter) - return strategy.filter.(values)::AbstractVector{Bool} + return findall(strategy.filter, values) end function findtruncated(values::AbstractVector, strategy::TruncationByValue) From 77670669f1c3e5f15cf4402d7b9fbe63f2625162 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 25 Sep 2025 13:26:32 -0400 Subject: [PATCH 17/17] future-proof the truncation tests --- test/truncate.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/test/truncate.jl b/test/truncate.jl index 2c13c2eb6..0fc28938c 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -32,38 +32,38 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, @test trunc == truncrank(maxrank) & trunctol(; atol, rtol) values = [1, 0.9, 0.5, -0.3, 0.01] - @test @constinferred(findtruncated(values, truncrank(2))) == 1:2 - @test @constinferred(findtruncated(values, truncrank(2; rev=false))) == [5, 4] - @test @constinferred(findtruncated(values, truncrank(2; by=((-) ∘ abs)))) == [5, 4] - @test @constinferred(findtruncated_svd(values, truncrank(2))) === 1:2 + @test values[@constinferred(findtruncated(values, truncrank(2)))] == values[1:2] + @test values[@constinferred(findtruncated(values, truncrank(2; rev=false)))] == values[[5, 4]] + @test values[@constinferred(findtruncated(values, truncrank(2; by=((-) ∘ abs))))] == values[[5, 4]] + @test values[@constinferred(findtruncated_svd(values, truncrank(2)))] == values[1:2] values = [1, 0.9, 0.5, -0.3, 0.01] strategy = trunctol(; atol=0.4) - @test findall(@constinferred(findtruncated(values, strategy))) == 1:3 - @test @constinferred(findtruncated_svd(values, strategy)) === 1:3 + @test values[@constinferred(findtruncated(values, strategy))] == values[1:3] + @test values[@constinferred(findtruncated_svd(values, strategy))] == values[1:3] strategy = trunctol(; atol=0.4, keep_below=true) - @test findall(@constinferred(findtruncated(values, strategy))) == 4:5 - @test @constinferred(findtruncated_svd(values, strategy)) === 4:5 + @test values[@constinferred(findtruncated(values, strategy))] == values[4:5] + @test values[@constinferred(findtruncated_svd(values, strategy))] == values[4:5] values = [0.01, 1, 0.9, -0.3, 0.5] for strategy in (trunctol(; atol=0.4), trunctol(; atol=0.2, by=identity)) - @test findall(@constinferred(findtruncated(values, strategy))) == [2, 3, 5] + @test values[@constinferred(findtruncated(values, strategy))] == values[[2, 3, 5]] end strategy = trunctol(; atol=0.2) - @test findall(@constinferred(findtruncated(values, strategy))) == [2, 3, 4, 5] + @test values[@constinferred(findtruncated(values, strategy))] == values[[2, 3, 4, 5]] for strategy in (trunctol(; atol=0.4, keep_below=true), trunctol(; atol=0.2, by=identity, keep_below=true)) - @test findall(@constinferred(findtruncated(values, strategy))) == [1, 4] + @test values[@constinferred(findtruncated(values, strategy))] == values[[1, 4]] end strategy = trunctol(; atol=0.2, keep_below=true) - @test findall(@constinferred(findtruncated(values, strategy))) == [1] + @test values[@constinferred(findtruncated(values, strategy))] == values[[1]] strategy = truncfilter(x -> 0.1 < x < 1) - @test findall(@constinferred(findtruncated(values, strategy))) == [3, 5] + @test values[@constinferred(findtruncated(values, strategy))] == values[[3, 5]] strategy = truncerror(; atol=0.2, rtol=0) - @test issetequal(@constinferred(findtruncated(values, strategy)), 2:5) + @test issetequal(values[@constinferred(findtruncated(values, strategy))], values[2:5]) vals_sorted = sort(values; by=abs, rev=true) - @test @constinferred(findtruncated_svd(vals_sorted, strategy)) == 1:4 + @test vals_sorted[@constinferred(findtruncated_svd(vals_sorted, strategy))] == vals_sorted[1:4] end