Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/src/user_interface/truncations.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,16 @@ combined_trunc = truncrank(10) & trunctol(; atol = 1e-6);

## Truncation Error

When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned.
When using truncated decompositions such as [`svd_trunc_with_err`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned.
This error is defined as the 2-norm of the discarded singular values or eigenvalues, providing a measure of the approximation quality.
For `svd_trunc` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix.
For `svd_trunc_with_err` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix.
For the case of `eig_trunc`, this interpretation does not hold because the norm of the non-unitary matrix of eigenvectors and its inverse also influence the approximation quality.


For example:
```jldoctest truncations; output=false
using LinearAlgebra: norm
U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2))
U, S, Vᴴ, ϵ = svd_trunc_with_err(A; trunc=truncrank(2))
norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values

# output
Expand Down
27 changes: 23 additions & 4 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,15 @@ for svd_f in (:svd_compact, :svd_full)
end
end

function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm)
function ChainRulesCore.rrule(::typeof(svd_trunc_with_err!), A, USVᴴ, alg::TruncatedAlgorithm)
Ac = copy_input(svd_compact, A)
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
return (USVᴴ′..., ϵ), _make_svd_trunc_with_err_pullback(A, USVᴴ, ind)
end
function _make_svd_trunc_pullback(A, USVᴴ, ind)
function svd_trunc_pullback(ΔUSVᴴϵ)
function _make_svd_trunc_with_err_pullback(A, USVᴴ, ind)
function svd_trunc_with_err_pullback(ΔUSVᴴϵ)
ΔA = zero(A)
ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
Expand All @@ -187,6 +187,25 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind)
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function svd_trunc_with_err_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return svd_trunc_with_err_pullback
end

function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm)
Ac = copy_input(svd_compact, A)
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind)
end
function _make_svd_trunc_pullback(A, USVᴴ, ind)
function svd_trunc_pullback(ΔUSVᴴ)
ΔA = zero(A)
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
Expand Down
36 changes: 32 additions & 4 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,14 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
return S_codual, svd_vals_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_with_err), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_with_err)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A_ = Mooncake.primal(A_dA)
dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
alg = Mooncake.primal(alg_dalg)
output = svd_trunc(A, alg)
output = svd_trunc_with_err(A, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
Expand All @@ -319,7 +319,35 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc_with_err does not yet support non-zero tangent for the truncation error"
U, dU = arrayify(Utrunc, dUtrunc_)
S, dS = arrayify(Strunc, dStrunc_)
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
MatrixAlgebraKit.zero!(dU)
MatrixAlgebraKit.zero!(dS)
MatrixAlgebraKit.zero!(dVᴴ)
return NoRData(), NoRData(), NoRData()
end
return output_codual, svd_trunc_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A_ = Mooncake.primal(A_dA)
dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
alg = Mooncake.primal(alg_dalg)
output = svd_trunc(A, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function svd_trunc_adjoint(::NoRData)
Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual)
dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual)
U, dU = arrayify(Utrunc, dUtrunc_)
S, dS = arrayify(Strunc, dStrunc_)
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
Expand Down
4 changes: 2 additions & 2 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ export project_hermitian, project_antihermitian, project_isometric
export project_hermitian!, project_antihermitian!, project_isometric!
export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
export svd_compact, svd_full, svd_vals, svd_trunc
export svd_compact!, svd_full!, svd_vals!, svd_trunc!
export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_with_err
export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_with_err!
export eigh_full, eigh_vals, eigh_trunc
export eigh_full!, eigh_vals!, eigh_trunc!
export eig_full, eig_vals, eig_trunc
Expand Down
80 changes: 60 additions & 20 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltyp
copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A)
copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A)
copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A)
copy_input(::typeof(svd_trunc_with_err), A) = copy_input(svd_compact, A)

copy_input(::typeof(svd_full), A::Diagonal) = copy(A)

Expand Down Expand Up @@ -92,6 +93,9 @@ end
function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm)
return initialize_output(svd_compact!, A, alg.alg)
end
function initialize_output(::typeof(svd_trunc_with_err!), A, alg::TruncatedAlgorithm)
return initialize_output(svd_compact!, A, alg.alg)
end

function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithm)
TA = eltype(A)
Expand Down Expand Up @@ -206,19 +210,16 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
return S
end

function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ}
ϵ = similar(A, real(eltype(A)), compute_error)
(U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg)
return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ)))
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
return USVᴴtrunc
end

function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ}
U, S, Vᴴ, ϵ = USVᴴϵ
U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg)
function svd_trunc_with_err!(A, USVᴴ, alg::TruncatedAlgorithm)
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
if !isempty(ϵ)
ϵ .= truncation_error!(diagview(S), ind)
end
ϵ = truncation_error!(diagview(S), ind)
return USVᴴtrunc..., ϵ
end

Expand Down Expand Up @@ -287,6 +288,22 @@ function check_input(
return nothing
end

function check_input(
Comment thread
kshyatt marked this conversation as resolved.
Outdated
::typeof(svd_trunc_with_err!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized
)
m, n = size(A)
minmn = min(m, n)
U, S, Vᴴ = USVᴴ
@assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
@check_size(U, (m, m))
@check_scalar(U, A)
@check_size(S, (minmn, minmn))
@check_scalar(S, A, real)
@check_size(Vᴴ, (n, n))
@check_scalar(Vᴴ, A)
return nothing
end

function initialize_output(
::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}
)
Expand All @@ -298,6 +315,17 @@ function initialize_output(
return (U, S, Vᴴ)
end

function initialize_output(
::typeof(svd_trunc_with_err!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}
Comment thread
kshyatt marked this conversation as resolved.
Outdated
)
m, n = size(A)
minmn = min(m, n)
U = similar(A, (m, m))
S = Diagonal(similar(A, real(eltype(A)), (minmn,)))
Vᴴ = similar(A, (n, n))
return (U, S, Vᴴ)
end

function _gpu_gesvd!(
A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix
)
Expand Down Expand Up @@ -372,22 +400,34 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
return USVᴴ
end

function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ}
U, S, Vᴴ, ϵ = USVᴴϵ
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
U, S, Vᴴ = USVᴴ
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)

# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)

if !isempty(ϵ)
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
normS = norm(diagview(Str))
normA = norm(A)
# equivalent to sqrt(normA^2 - normS^2)
# but may be more accurate
ϵ = sqrt((normA + normS) * (normA - normS))
end
do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)

return Utr, Str, Vᴴtr
end

function svd_trunc_with_err!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
U, S, Vᴴ = USVᴴ
check_input(svd_trunc_with_err!, A, (U, S, Vᴴ), alg.alg)
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)

# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)

# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
normS = norm(diagview(Str))
normA = norm(A)
# equivalent to sqrt(normA^2 - normS^2)
# but may be more accurate
ϵ = sqrt((normA + normS) * (normA - normS))

do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)
Expand Down
74 changes: 62 additions & 12 deletions src/interface/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and
@functiondef svd_compact

"""
svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
svd_trunc_with_err(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
svd_trunc_with_err(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
svd_trunc_with_err!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
svd_trunc_with_err!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ

Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
Expand Down Expand Up @@ -81,6 +81,54 @@ for the default algorithm selection behavior.
When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the
truncation strategy is already embedded in the algorithm.

!!! note
The bang method `svd_trunc!` optionally accepts the output structure and
possibly destroys the input matrix `A`. Always use the return value of the function
as it may not always be possible to use the provided `USVᴴ` as output.

See also [`svd_trunc(!)`](@ref svd_trunc), [`svd_full(!)`](@ref svd_full),
[`svd_compact(!)`](@ref svd_compact), [`svd_vals(!)`](@ref svd_vals),
and [Truncations](@ref) for more information on truncation strategies.
"""
@functiondef svd_trunc_with_err

"""
svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ
svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ

Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
`(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a
square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy.

## Truncation
The truncation strategy can be controlled via the `trunc` keyword argument. This can be
either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or
nothing, all values will be kept.

### `trunc::NamedTuple`
The supported truncation keyword arguments are:

$docs_truncation_kwargs

### `trunc::TruncationStrategy`
For more control, a truncation strategy can be supplied directly.
By default, MatrixAlgebraKit supplies the following:

$docs_truncation_strategies

## Keyword arguments
Other keyword arguments are passed to the algorithm selection procedure. If no explicit
`alg` is provided, these keywords are used to select and configure the algorithm through
[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm
selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref)
for the default algorithm selection behavior.

When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the
truncation strategy is already embedded in the algorithm.

!!! note
The bang method `svd_trunc!` optionally accepts the output structure and
possibly destroys the input matrix `A`. Always use the return value of the function
Expand Down Expand Up @@ -125,13 +173,15 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!)
end
end

function select_algorithm(::typeof(svd_trunc!), A, alg; trunc = nothing, kwargs...)
if alg isa TruncatedAlgorithm
isnothing(trunc) ||
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
return alg
else
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
for f in (:svd_trunc!, :svd_trunc_with_err!)
@eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do the Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)} again, or is there an advantage to this approach?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I did it this way mostly to follow the convention above, not sure it matters much?

if alg isa TruncatedAlgorithm
isnothing(trunc) ||
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
return alg
else
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
end
end
end
4 changes: 2 additions & 2 deletions test/amd/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,14 @@ end
# minmn = min(m, n)
# r = minmn - 2
#
# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r))
# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc=truncrank(r))
# @test length(S1.diag) == r
# @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]
#
# s = 1 + sqrt(eps(real(T)))
# trunc2 = trunctol(; atol=s * S₀[r + 1])
#
# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1]))
# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc=trunctol(; atol=s * S₀[r + 1]))
# @test length(S2.diag) == r
# @test U1 ≈ U2
# @test S1 ≈ S2
Expand Down
Loading