Skip to content

Commit 523eef1

Browse files
committed
Merge branch 'main' into copilot/update-docstrings-eigenvalue-decompositions
2 parents cef025d + bd96dfb commit 523eef1

18 files changed

Lines changed: 372 additions & 138 deletions

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ MatrixAlgebraKit.iszerotangent(::AbstractZero) = true
1414
@non_differentiable MatrixAlgebraKit.select_algorithm(args...)
1515
@non_differentiable MatrixAlgebraKit.initialize_output(args...)
1616
@non_differentiable MatrixAlgebraKit.check_input(args...)
17-
@non_differentiable MatrixAlgebraKit.isisometry(args...)
17+
@non_differentiable MatrixAlgebraKit.isisometric(args...)
1818
@non_differentiable MatrixAlgebraKit.isunitary(args...)
1919

2020
function ChainRulesCore.rrule(::typeof(copy_input), f, A)

src/MatrixAlgebraKit.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@ module MatrixAlgebraKit
33
using LinearAlgebra: LinearAlgebra
44
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
55
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
6-
using LinearAlgebra: sylvester
6+
using LinearAlgebra: sylvester, lu!
77
using LinearAlgebra: isposdef, issymmetric
88
using LinearAlgebra: Diagonal, diag, diagind, isdiag
99
using LinearAlgebra: UpperTriangular, LowerTriangular
1010
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt
1111

12-
export isisometry, isunitary, ishermitian, isantihermitian
12+
export isisometric, isunitary, ishermitian, isantihermitian
1313

14-
export project_hermitian, project_antihermitian
15-
export project_hermitian!, project_antihermitian!
14+
export project_hermitian, project_antihermitian, project_isometric
15+
export project_hermitian!, project_antihermitian!, project_isometric!
1616
export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
1717
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
1818
export svd_compact, svd_full, svd_vals, svd_trunc
@@ -34,6 +34,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3434
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3535
LAPACK_DivideAndConquer, LAPACK_Jacobi
3636
export LQViaTransposedQR
37+
export PolarViaSVD, PolarNewton
3738
export DiagonalAlgorithm
3839
export NativeBlocked
3940
export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar,
@@ -64,6 +65,7 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
6465
:svd_pullback!, :svd_trunc_pullback!
6566
)
6667
)
68+
eval(Expr(:public, :is_left_isometric, :is_right_isometric))
6769
end
6870

6971
include("common/defaults.jl")

src/common/matrixproperties.jl

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
isisometry(A; side=:left, isapprox_kwargs...) -> Bool
2+
isisometric(A; side=:left, isapprox_kwargs...) -> Bool
33
44
Test whether a linear map is an isometry, where the type of isometry is controlled by `kind`:
55
@@ -8,13 +8,14 @@ Test whether a linear map is an isometry, where the type of isometry is controll
88
99
The `isapprox_kwargs` are passed on to `isapprox` to control the tolerances.
1010
11-
New specializations should overload [`is_left_isometry`](@ref) and [`is_right_isometry`](@ref).
11+
New specializations should overload [`MatrixAlgebraKit.is_left_isometric`](@ref) and
12+
[`MatrixAlgebraKit.is_right_isometric`](@ref).
1213
1314
See also [`isunitary`](@ref).
1415
"""
15-
function isisometry(A; side::Symbol = :left, isapprox_kwargs...)
16-
side === :left && return is_left_isometry(A; isapprox_kwargs...)
17-
side === :right && return is_right_isometry(A; isapprox_kwargs...)
16+
function isisometric(A; side::Symbol = :left, isapprox_kwargs...)
17+
side === :left && return is_left_isometric(A; isapprox_kwargs...)
18+
side === :right && return is_right_isometric(A; isapprox_kwargs...)
1819

1920
throw(ArgumentError(lazy"Invalid isometry side: $side"))
2021
end
@@ -25,48 +26,42 @@ end
2526
Test whether a linear map is unitary, i.e. `A * A' ≈ I ≈ A' * A`.
2627
The `isapprox_kwargs` are passed on to `isapprox` to control the tolerances.
2728
28-
See also [`isisometry`](@ref).
29+
See also [`isisometric`](@ref).
2930
"""
3031
function isunitary(A; isapprox_kwargs...)
31-
return is_left_isometry(A; isapprox_kwargs...) &&
32-
is_right_isometry(A; isapprox_kwargs...)
32+
return is_left_isometric(A; isapprox_kwargs...) &&
33+
is_right_isometric(A; isapprox_kwargs...)
3334
end
3435
function isunitary(A::AbstractMatrix; isapprox_kwargs...)
3536
size(A, 1) == size(A, 2) || return false
36-
return is_left_isometry(A; isapprox_kwargs...)
37+
return is_left_isometric(A; isapprox_kwargs...)
3738
end
3839

3940
@doc """
40-
is_left_isometry(A; isapprox_kwargs...) -> Bool
41+
is_left_isometric(A; isapprox_kwargs...) -> Bool
4142
42-
Test whether a linear map is a left isometry, i.e. `A' * A ≈ I`.
43+
Test whether a linear map is a (left) isometry, i.e. `A' * A ≈ I`.
4344
The `isapprox_kwargs` can be used to control the tolerances of the equality.
4445
45-
See also [`isisometry`](@ref) and [`is_right_isometry`](@ref).
46-
""" is_left_isometry
46+
See also [`isisometric`](@ref) and [`MatrixAlgebraKit.is_right_isometric`](@ref).
47+
""" is_left_isometric
4748

48-
function is_left_isometry(A::AbstractMatrix; atol::Real = 0, rtol::Real = defaulttol(A), norm = LinearAlgebra.norm)
49+
function is_left_isometric(A::AbstractMatrix; atol::Real = 0, rtol::Real = defaulttol(A), norm = LinearAlgebra.norm)
4950
P = A' * A
5051
nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))`
5152
diagview(P) .-= 1
5253
return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
5354
end
5455

5556
@doc """
56-
is_right_isometry(A; isapprox_kwargs...) -> Bool
57+
is_right_isometric(A; isapprox_kwargs...) -> Bool
5758
58-
Test whether a linear map is a right isometry, i.e. `A * A' ≈ I`.
59+
Test whether a linear map is a (right) isometry, i.e. `A * A' ≈ I`.
5960
The `isapprox_kwargs` can be used to control the tolerances of the equality.
6061
61-
See also [`isisometry`](@ref) and [`is_left_isometry`](@ref).
62-
""" is_right_isometry
63-
64-
function is_right_isometry(A::AbstractMatrix; atol::Real = 0, rtol::Real = defaulttol(A), norm = LinearAlgebra.norm)
65-
P = A * A'
66-
nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))`
67-
diagview(P) .-= 1
68-
return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
69-
end
62+
See also [`isisometric`](@ref) and [`MatrixAlgebraKit.is_left_isometric`](@ref).
63+
""" is_right_isometric
64+
is_right_isometric(A; kwargs...) = is_left_isometric(A'; kwargs...)
7065

7166
"""
7267
ishermitian(A; isapprox_kwargs...)

src/implementations/polar.jl

Lines changed: 141 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function check_input(::typeof(left_polar!), A::AbstractMatrix, WP, ::AbstractAlg
1111
@assert W isa AbstractMatrix && P isa AbstractMatrix
1212
@check_size(W, (m, n))
1313
@check_scalar(W, A)
14-
@check_size(P, (n, n))
14+
isempty(P) || @check_size(P, (n, n))
1515
@check_scalar(P, A)
1616
return nothing
1717
end
@@ -21,7 +21,7 @@ function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, ::Abstrac
2121
n >= m ||
2222
throw(ArgumentError("input matrix needs at least as many columns as rows"))
2323
@assert P isa AbstractMatrix && Wᴴ isa AbstractMatrix
24-
@check_size(P, (m, m))
24+
isempty(P) || @check_size(P, (m, m))
2525
@check_scalar(P, A)
2626
@check_size(Wᴴ, (m, n))
2727
@check_scalar(Wᴴ, A)
@@ -43,25 +43,154 @@ function initialize_output(::typeof(right_polar!), A::AbstractMatrix, ::Abstract
4343
return (P, Wᴴ)
4444
end
4545

46-
# Implementation
47-
# --------------
46+
# Implementation via SVD
47+
# -----------------------
4848
function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD)
4949
check_input(left_polar!, A, WP, alg)
50-
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
50+
U, S, Vᴴ = svd_compact!(A, alg.svd_alg)
5151
W, P = WP
5252
W = mul!(W, U, Vᴴ)
53-
S .= sqrt.(S)
54-
SsqrtVᴴ = lmul!(S, Vᴴ)
55-
P = mul!(P, SsqrtVᴴ', SsqrtVᴴ)
53+
if !isempty(P)
54+
S .= sqrt.(S)
55+
SsqrtVᴴ = lmul!(S, Vᴴ)
56+
P = mul!(P, SsqrtVᴴ', SsqrtVᴴ)
57+
end
5658
return (W, P)
5759
end
5860
function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD)
5961
check_input(right_polar!, A, PWᴴ, alg)
60-
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
62+
U, S, Vᴴ = svd_compact!(A, alg.svd_alg)
6163
P, Wᴴ = PWᴴ
6264
Wᴴ = mul!(Wᴴ, U, Vᴴ)
63-
S .= sqrt.(S)
64-
USsqrt = rmul!(U, S)
65-
P = mul!(P, USsqrt, USsqrt')
65+
if !isempty(P)
66+
S .= sqrt.(S)
67+
USsqrt = rmul!(U, S)
68+
P = mul!(P, USsqrt, USsqrt')
69+
end
6670
return (P, Wᴴ)
6771
end
72+
73+
# Implementation via Newton
74+
# --------------------------
75+
function left_polar!(A::AbstractMatrix, WP, alg::PolarNewton)
76+
check_input(left_polar!, A, WP, alg)
77+
W, P = WP
78+
if isempty(P)
79+
W = _left_polarnewton!(A, W, P; alg.kwargs...)
80+
return W, P
81+
else
82+
W = _left_polarnewton!(copy(A), W, P; alg.kwargs...)
83+
# we still need `A` to compute `P`
84+
P = project_hermitian!(mul!(P, W', A))
85+
return W, P
86+
end
87+
end
88+
89+
function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarNewton)
90+
check_input(right_polar!, A, PWᴴ, alg)
91+
P, Wᴴ = PWᴴ
92+
if isempty(P)
93+
Wᴴ = _right_polarnewton!(A, Wᴴ, P; alg.kwargs...)
94+
return P, Wᴴ
95+
else
96+
Wᴴ = _right_polarnewton!(copy(A), Wᴴ, P; alg.kwargs...)
97+
# we still need `A` to compute `P`
98+
P = project_hermitian!(mul!(P, A, Wᴴ'))
99+
return P, Wᴴ
100+
end
101+
end
102+
103+
# these methods only compute W and destroy A in the process
104+
function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
105+
m, n = size(A) # we must have m >= n
106+
Rᴴinv = isempty(P) ? similar(P, (n, n)) : P # use P as workspace when available
107+
if m > n # initial QR
108+
Q, R = qr_compact!(A)
109+
Rc = view(A, 1:n, 1:n)
110+
copy!(Rc, R)
111+
Rᴴinv = ldiv!(UpperTriangular(Rc)', one!(Rᴴinv))
112+
else # m == n
113+
R = A
114+
Rc = view(W, 1:n, 1:n)
115+
copy!(Rc, R)
116+
Rᴴinv = ldiv!(lu!(Rc)', one!(Rᴴinv))
117+
end
118+
γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
119+
rmul!(R, γ)
120+
rmul!(Rᴴinv, 1 / γ)
121+
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
122+
copy!(Rc, R)
123+
i = 1
124+
conv = norm(Rᴴinv, Inf)
125+
while i < maxiter && conv > tol
126+
Rᴴinv = ldiv!(lu!(Rc)', one!(Rᴴinv))
127+
γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
128+
rmul!(R, γ)
129+
rmul!(Rᴴinv, 1 / γ)
130+
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
131+
copy!(Rc, R)
132+
conv = norm(Rᴴinv, Inf)
133+
i += 1
134+
end
135+
if conv > tol
136+
@warn "`left_polar!` via Newton iteration did not converge within $maxiter iterations (final residual: $conv)"
137+
end
138+
if m > n
139+
return mul!(W, Q, Rc)
140+
end
141+
return W
142+
end
143+
144+
function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
145+
m, n = size(A) # we must have m <= n
146+
Lᴴinv = isempty(P) ? similar(P, (m, m)) : P # use P as workspace when available
147+
if m < n # initial QR
148+
L, Q = lq_compact!(A)
149+
Lc = view(A, 1:m, 1:m)
150+
copy!(Lc, L)
151+
Lᴴinv = ldiv!(LowerTriangular(Lc)', one!(Lᴴinv))
152+
else # m == n
153+
L = A
154+
Lc = view(Wᴴ, 1:m, 1:m)
155+
copy!(Lc, L)
156+
Lᴴinv = ldiv!(lu!(Lc)', one!(Lᴴinv))
157+
end
158+
γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
159+
rmul!(L, γ)
160+
rmul!(Lᴴinv, 1 / γ)
161+
L, Lᴴinv = _avgdiff!(L, Lᴴinv)
162+
copy!(Lc, L)
163+
i = 1
164+
conv = norm(Lᴴinv, Inf)
165+
while i < maxiter && conv > tol
166+
Lᴴinv = ldiv!(lu!(Lc)', one!(Lᴴinv))
167+
γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
168+
rmul!(L, γ)
169+
rmul!(Lᴴinv, 1 / γ)
170+
L, Lᴴinv = _avgdiff!(L, Lᴴinv)
171+
copy!(Lc, L)
172+
conv = norm(Lᴴinv, Inf)
173+
i += 1
174+
end
175+
if conv > tol
176+
@warn "`right_polar!` via Newton iteration did not converge within $maxiter iterations (final residual: $conv)"
177+
end
178+
if m < n
179+
return mul!(Wᴴ, Lc, Q)
180+
end
181+
return Wᴴ
182+
end
183+
184+
# in place computation of the average and difference of two arrays
185+
function _avgdiff!(A::AbstractArray, B::AbstractArray)
186+
axes(A) == axes(B) || throw(DimensionMismatch())
187+
@simd for I in eachindex(A, B)
188+
@inbounds begin
189+
a = A[I]
190+
b = B[I]
191+
A[I] = (a + b) / 2
192+
B[I] = b - a
193+
end
194+
end
195+
return A, B
196+
end

src/implementations/projections.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ function copy_input(::typeof(project_hermitian), A::AbstractMatrix)
55
end
66
copy_input(::typeof(project_antihermitian), A) = copy_input(project_hermitian, A)
77

8+
copy_input(::typeof(project_isometric), A) = copy_input(left_polar, A)
9+
810
function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
911
LinearAlgebra.checksquare(A)
1012
n = Base.require_one_based_indexing(A)
@@ -18,6 +20,16 @@ function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::Abs
1820
return nothing
1921
end
2022

23+
function check_input(::typeof(project_isometric!), A::AbstractMatrix, W::AbstractMatrix, ::AbstractAlgorithm)
24+
m, n = size(A)
25+
m >= n ||
26+
throw(ArgumentError("input matrix needs at least as many rows as columns"))
27+
@assert W isa AbstractMatrix
28+
@check_size(W, (m, n))
29+
@check_scalar(W, A)
30+
return nothing
31+
end
32+
2133
# Outputs
2234
# -------
2335
function initialize_output(::typeof(project_hermitian!), A::AbstractMatrix, ::NativeBlocked)
@@ -27,15 +39,26 @@ function initialize_output(::typeof(project_antihermitian!), A::AbstractMatrix,
2739
return A
2840
end
2941

42+
function initialize_output(::typeof(project_isometric!), A::AbstractMatrix, ::AbstractAlgorithm)
43+
return similar(A)
44+
end
45+
3046
# Implementation
3147
# --------------
32-
function project_hermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
33-
check_input(project_hermitian!, A, B, alg)
34-
return project_hermitian_native!(A, B, Val(false); alg.kwargs...)
48+
function project_hermitian!(A::AbstractMatrix, Aₕ, alg::NativeBlocked)
49+
check_input(project_hermitian!, A, Aₕ, alg)
50+
return project_hermitian_native!(A, Aₕ, Val(false); alg.kwargs...)
3551
end
36-
function project_antihermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
37-
check_input(project_antihermitian!, A, B, alg)
38-
return project_hermitian_native!(A, B, Val(true); alg.kwargs...)
52+
function project_antihermitian!(A::AbstractMatrix, Aₐ, alg::NativeBlocked)
53+
check_input(project_antihermitian!, A, Aₐ, alg)
54+
return project_hermitian_native!(A, Aₐ, Val(true); alg.kwargs...)
55+
end
56+
57+
function project_isometric!(A::AbstractMatrix, W, alg::AbstractAlgorithm)
58+
check_input(project_isometric!, A, W, alg)
59+
noP = similar(W, (0, 0))
60+
W, _ = left_polar!(A, (W, noP), alg)
61+
return W
3962
end
4063

4164
function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)

0 commit comments

Comments
 (0)