Skip to content

Commit e0dd0ef

Browse files
committed
Update right_orth
1 parent 19fd994 commit e0dd0ef

4 files changed

Lines changed: 70 additions & 73 deletions

File tree

src/implementations/orthnull.jl

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -121,38 +121,33 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing,
121121
end
122122
end
123123

124-
function right_orth!(A::AbstractMatrix, CVᴴ; kwargs...)
124+
function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing,
125+
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
126+
alg_polar=(;), alg_svd=(;))
125127
check_input(right_orth!, A, CVᴴ)
126-
atol = get(kwargs, :atol, 0)
127-
rtol = get(kwargs, :rtol, 0)
128-
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :lqpos : :svd)
129-
if !(iszero(atol) && iszero(rtol)) && kind != :svd
130-
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
128+
if !isnothing(trunc) && kind != :svd
129+
throw(ArgumentError("truncation not supported for right_orth with kind=$kind"))
131130
end
132131
if kind == :lq
133-
alg = get(kwargs, :alg, select_algorithm(lq_compact!, A))
134-
return lq_compact!(A, CVᴴ, alg)
135-
elseif kind == :lqpos
136-
alg = get(kwargs, :alg, select_algorithm(lq_compact!, A; positive=true))
137-
return lq_compact!(A, CVᴴ, alg)
132+
alg_lq′ = algorithm_or_select_algorithm(lq_compact!, A, alg_lq)
133+
return lq_compact!(A, CVᴴ, alg_lq′)
138134
elseif kind == :polar
139135
size(A, 2) >= size(A, 1) ||
140136
throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`"))
141-
alg = get(kwargs, :alg, select_algorithm(right_polar!, A))
142-
return right_polar!(A, CVᴴ, alg)
143-
elseif kind == :svd && iszero(atol) && iszero(rtol)
144-
alg = get(kwargs, :alg, select_algorithm(svd_compact!, A))
137+
alg_polar′ = algorithm_or_select_algorithm(right_polar!, A, alg_polar)
138+
return right_polar!(A, CVᴴ, alg_polar′)
139+
elseif kind == :svd && isnothing(trunc)
140+
alg_svd′ = algorithm_or_select_algorithm(svd_compact!, A, alg_svd)
145141
C, Vᴴ = CVᴴ
146-
S = Diagonal(initialize_output(svd_vals!, A, alg))
147-
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg)
142+
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
143+
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′)
148144
return rmul!(U, S), Vᴴ
149145
elseif kind == :svd
150-
alg_svd = select_algorithm(svd_compact!, A)
151-
trunc = TruncationKeepAbove(atol, rtol)
152-
alg = get(kwargs, :alg, TruncatedAlgorithm(alg_svd, trunc))
146+
alg_svd′ = algorithm_or_select_algorithm(svd_compact!, A, alg_svd)
147+
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
153148
C, Vᴴ = CVᴴ
154-
S = Diagonal(initialize_output(svd_vals!, A, alg_svd))
155-
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg)
149+
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
150+
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc)
156151
return rmul!(U, S), Vᴴ
157152
else
158153
throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`"))

src/interface/orthnull.jl

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,43 +19,44 @@ end
1919
# Orth functions
2020
# --------------
2121
"""
22-
left_orth(A; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> V, C
23-
left_orth!(A, [VC]; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> V, C
22+
left_orth(A; [kind::Symbol, trunc, alg_qr, alg_polar, alg_svd]) -> V, C
23+
left_orth!(A, [VC]; [kind::Symbol, trunc, alg_qr, alg_polar, alg_svd]) -> V, C
2424
2525
Compute an orthonormal basis `V` for the image of the matrix `A` of size `(m, n)`,
2626
as well as a matrix `C` (the corestriction) such that `A` factors as `A = V * C`.
2727
The keyword argument `kind` can be used to specify the specific orthogonal decomposition
28-
that should be used to factor `A`, whereas `atol` and `rtol` can be used to control the
28+
that should be used to factor `A`, whereas `trunc` can be used to control the
2929
precision in determining the rank of `A` via its singular values.
3030
3131
This is a high-level wrapper and will use one of the decompositions
32-
`qr!`, `svd!`, and `left_polar!` to compute the orthogonal basis `V`, as controlled
32+
`qr_compact!`, `svd_compact!`/`svd_trunc!`, and `left_polar!` to compute the orthogonal basis `V`, as controlled
3333
by the keyword arguments.
3434
3535
When `kind` is provided, its possible values are
3636
37-
* `kind == :qrpos`: `V` and `C` are computed using the positive QR decomposition.
38-
This requires `iszero(atol) && iszero(rtol)` and `left_orth!(A, [VC])` is equivalent to
37+
* `kind == :qr`: `V` and `C` are computed using the QR decomposition.
38+
This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to
3939
`qr_compact!(A, [VC], alg)` with a default value `alg = select_algorithm(qr_compact!, A; positive=true)`
4040
41-
* `kind == :qr`: `V` and `C` are computed using the QR decomposition,
42-
This requires `iszero(atol) && iszero(rtol)` and `left_orth!(A, [VC])` is equivalent to
43-
`qr_compact!(A, [VC], alg)` with a default value `alg = select_algorithm(qr_compact!, A)`
44-
4541
* `kind == :polar`: `V` and `C` are computed using the polar decomposition,
46-
This requires `iszero(atol) && iszero(rtol)` and `left_orth!(A, [VC])` is equivalent to
42+
This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to
4743
`left_polar!(A, [VC], alg)` with a default value `alg = select_algorithm(left_polar!, A)`
4844
49-
* `kind == :svd`: `V` and `C` are computed using the singular value decomposition `svd_trunc!`,
50-
where `V` will contain the left singular vectors corresponding to the singular values that
51-
are larger than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`.
52-
`C` is computed as the product of the singular values and the right singular vectors,
53-
i.e. with `U, S, Vᴴ = svd_trunc!(A)`, we have `V = U` and `C = S * Vᴴ`.
45+
* `kind == :svd`: `V` and `C` are computed using the singular value decomposition `svd_compact!`
46+
if no truncation is specified through the `trunc` keyword argument or `svd_trunc!`
47+
if truncation is specified through the `trunc` keyword argument.
48+
`V` will contain the left singular vectors and `C` is computed as the product of the singular
49+
values and the right singular vectors, i.e. with `U, S, Vᴴ = svd(A)`, we have
50+
`V = U` and `C = S * Vᴴ`.
5451
55-
When `kind` is not provided, the default value is `:qrpos` when `iszero(atol) && iszero(rtol)`
52+
When `kind` is not provided, the default value is `:qr` when `isnothing(trunc)`
5653
and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm
57-
using the `alg` keyword argument, which should be compatible with the chosen or default value
58-
of `kind`.
54+
for backend factorizations through the `alg_qr`, `alg_polar`, and `alg_svd` keyword arguments,
55+
which will only be used if the corresponding factorization is called based on the other inputs.
56+
If NamedTuples are passed as `alg_qr`, `alg_polar`, or `alg_svd`, a default algorithm is chosen
57+
with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm.
58+
`alg_qr` defaults to `(; positive=true)` so that by default a positive QR decomposition will
59+
be used.
5960
6061
!!! note
6162
The bang method `left_orth!` optionally accepts the output structure and possibly destroys
@@ -80,37 +81,38 @@ end
8081
Compute an orthonormal basis `V = adjoint(Vᴴ)` for the coimage of the matrix `A`, i.e.
8182
for the image of `adjoint(A)`, as well as a matrix `C` such that `A = C * Vᴴ`.
8283
The keyword argument `kind` can be used to specify the specific orthogonal decomposition
83-
that should be used to factor `A`, whereas `atol` and `rtol` can be used to control the
84+
that should be used to factor `A`, whereas `trunc` can be used to control the
8485
precision in determining the rank of `A` via its singular values.
8586
8687
This is a high-level wrapper and will use call one of the decompositions
87-
`qr!`, `svd!`, and `left_polar!` to compute the orthogonal basis `V`, as controlled
88-
by the keyword arguments.
88+
`lq_compact!`, `svd_compact!`/`svd_trunc!`, and `right_polar!` to compute the
89+
orthogonal basis `V`, as controlled by the keyword arguments.
8990
9091
When `kind` is provided, its possible values are
9192
92-
* `kind == :lqpos`: `C` and `Vᴴ` are computed using the positive QR decomposition.
93-
This requires `iszero(atol) && iszero(rtol)` and `right_orth!(A, [CVᴴ])` is equivalent to
94-
`lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)`
95-
9693
* `kind == :lq`: `C` and `Vᴴ` are computed using the QR decomposition,
97-
This requires `iszero(atol) && iszero(rtol)` and `right_orth!(A, [CVᴴ])` is equivalent to
98-
`lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A))`
94+
This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to
95+
`lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)`
9996
10097
* `kind == :polar`: `C` and `Vᴴ` are computed using the polar decomposition,
101-
This requires `iszero(atol) && iszero(rtol)` and `right_orth!(A, [CVᴴ])` is equivalent to
102-
`right_polar!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(right_polar!, A))`
98+
This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to
99+
`right_polar!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(right_polar!, A)`
103100
104-
* `kind == :svd`: `C` and `Vᴴ` are computed using the singular value decomposition `svd_trunc!`,
105-
where `V = adjoint(Vᴴ)` will contain the right singular vectors corresponding to the singular
106-
values that are larger than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`.
107-
`C` is computed as the product of the singular values and the right singular vectors,
108-
i.e. with `U, S, Vᴴ = svd_trunc!(A)`, we have `C = rmul!(U, S)` and `Vᴴ = Vᴴ`.
101+
* `kind == :svd`: `C` and `Vᴴ` are computed using the singular value decomposition `svd_compact!`
102+
if no truncation is specified through the `trunc` keyword argument or `svd_trunc!`
103+
if truncation is specified through the `trunc` keyword argument.
104+
`V = adjoint(Vᴴ)` will contain the right singular vectors corresponding to the singular
105+
values and `C` is computed as the product of the singular values and the right singular vectors,
106+
i.e. with `U, S, Vᴴ = svd(A)`, we have `C = rmul!(U, S)` and `Vᴴ = Vᴴ`.
109107
110-
When `kind` is not provided, the default value is `:lqpos` when `iszero(atol) && iszero(rtol)`
108+
When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)`
111109
and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm
112-
using the `alg` keyword argument, which should be compatible with the chosen or default value
113-
of `kind`.
110+
for backend factorizations through the `alg_lq`, `alg_polar`, and `alg_svd` keyword arguments,
111+
which will only be used if the corresponding factorization is called based on the other inputs.
112+
If NamedTuples are passed as `alg_lq`, `alg_polar`, or `alg_svd`, a default algorithm is chosen
113+
with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm.
114+
`alg_lq` defaults to `(; positive=true)` so that by default a positive QR decomposition will
115+
be used.
114116
115117
!!! note
116118
The bang method `right_orth!` optionally accepts the output structure and possibly destroys

test/chainrules.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,26 +338,26 @@ end
338338
config = Zygote.ZygoteRuleConfig()
339339
test_rrule(config, left_orth, A;
340340
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
341-
test_rrule(config, left_orth, A; fkwargs=(; kind=:qrpos),
341+
test_rrule(config, left_orth, A; fkwargs=(; kind=:qr),
342342
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
343343
m >= n &&
344344
test_rrule(config, left_orth, A; fkwargs=(; kind=:polar),
345345
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
346346

347-
ΔN = left_orth(A; kind=:qrpos)[1] * randn(rng, T, min(m, n), m - min(m, n))
348-
test_rrule(config, left_null, A; fkwargs=(; kind=:qrpos), output_tangent=ΔN,
347+
ΔN = left_orth(A; kind=:qr)[1] * randn(rng, T, min(m, n), m - min(m, n))
348+
test_rrule(config, left_null, A; fkwargs=(; kind=:qr), output_tangent=ΔN,
349349
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
350350

351351
test_rrule(config, right_orth, A;
352352
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
353-
test_rrule(config, right_orth, A; fkwargs=(; kind=:lqpos),
353+
test_rrule(config, right_orth, A; fkwargs=(; kind=:lq),
354354
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
355355
m <= n &&
356356
test_rrule(config, right_orth, A; fkwargs=(; kind=:polar),
357357
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
358358

359-
ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lqpos)[2]
360-
test_rrule(config, right_null, A; fkwargs=(; kind=:lqpos), output_tangent=ΔNᴴ,
359+
ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lq)[2]
360+
test_rrule(config, right_null, A; fkwargs=(; kind=:lq), output_tangent=ΔNᴴ,
361361
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
362362
end
363363
end

test/orthnull.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ end
141141
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 I
142142

143143
atol = eps(real(T))
144-
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); atol=atol)
144+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol))
145145
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; atol=atol)
146146
@test C2 !== C
147147
@test Vᴴ2 !== Vᴴ
@@ -153,7 +153,7 @@ end
153153
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 I
154154

155155
rtol = eps(real(T))
156-
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); rtol=rtol)
156+
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; rtol=rtol))
157157
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; rtol=rtol)
158158
@test C2 !== C
159159
@test Vᴴ2 !== Vᴴ
@@ -164,7 +164,7 @@ end
164164
@test Nᴴ2 * Nᴴ2' I
165165
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 I
166166

167-
for kind in (:lq, :lqpos, :polar, :svd)
167+
for kind in (:lq, :polar, :svd)
168168
n < m && kind == :polar && continue
169169
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind)
170170
@test C2 === C
@@ -181,7 +181,7 @@ end
181181

182182
if kind == :svd
183183
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
184-
atol=atol)
184+
trunc=(; atol=atol))
185185
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, atol=atol)
186186
@test C2 !== C
187187
@test Vᴴ2 !== Vᴴ
@@ -193,7 +193,7 @@ end
193193
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 I
194194

195195
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
196-
rtol=rtol)
196+
trunc=(; rtol=rtol))
197197
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, rtol=rtol)
198198
@test C2 !== C
199199
@test Vᴴ2 !== Vᴴ
@@ -205,9 +205,9 @@ end
205205
@test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 I
206206
else
207207
@test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
208-
atol=atol)
208+
trunc=(; atol=atol))
209209
@test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind,
210-
rtol=rtol)
210+
trunc=(; rtol=rtol))
211211
@test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind,
212212
atol=atol)
213213
@test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind,

0 commit comments

Comments
 (0)