Skip to content

Commit 09715df

Browse files
authored
some more updates on algorithms (#196)
1 parent 6cc7924 commit 09715df

8 files changed

Lines changed: 114 additions & 112 deletions

File tree

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix;
5353
return Dd, V
5454
end
5555

56-
function MatrixAlgebraKit.householder_qr!(
56+
function MatrixAlgebraKit.qr_householder!(
5757
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
5858
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
5959
)
@@ -97,7 +97,7 @@ function MatrixAlgebraKit.householder_qr!(
9797
return Q, R
9898
end
9999

100-
function MatrixAlgebraKit.householder_qr_null!(
100+
function MatrixAlgebraKit.qr_null_householder!(
101101
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
102102
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
103103
)

src/implementations/eig.jl

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -103,44 +103,45 @@ for f! in (:geev!, :geevx!)
103103
end
104104

105105
# driver dispatch
106-
@inline qr_iteration_eig_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) =
107-
qr_iteration_eig_full!(driver, A, Dd, V; kwargs...)
108-
@inline qr_iteration_eig_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) =
109-
qr_iteration_eig_vals!(driver, A, D, V; kwargs...)
106+
@inline eig_full_qr_iteration!(A, DV; driver::Driver = DefaultDriver(), kwargs...) =
107+
eig_full_qr_iteration!(driver, A, DV; kwargs...)
108+
@inline eig_vals_qr_iteration!(A, D; driver::Driver = DefaultDriver(), kwargs...) =
109+
eig_vals_qr_iteration!(driver, A, D; kwargs...)
110110

111-
@inline qr_iteration_eig_full!(::DefaultDriver, A, Dd, V; kwargs...) =
112-
qr_iteration_eig_full!(default_driver(QRIteration, A), A, Dd, V; kwargs...)
113-
@inline qr_iteration_eig_vals!(::DefaultDriver, A, D, V; kwargs...) =
114-
qr_iteration_eig_vals!(default_driver(QRIteration, A), A, D, V; kwargs...)
111+
@inline eig_full_qr_iteration!(::DefaultDriver, A, DV; kwargs...) =
112+
eig_full_qr_iteration!(default_driver(QRIteration, A), A, DV; kwargs...)
113+
@inline eig_vals_qr_iteration!(::DefaultDriver, A, D; kwargs...) =
114+
eig_vals_qr_iteration!(default_driver(QRIteration, A), A, D; kwargs...)
115115

116116
# Implementation
117-
function qr_iteration_eig_full!(
118-
driver::Driver, A, Dd, V;
117+
function eig_full_qr_iteration!(
118+
driver::Driver, A, DV;
119119
fixgauge::Bool = default_fixgauge(), scale::Bool = true, permute::Bool = true
120120
)
121+
D, V = DV
122+
Dd = diagview(D)
121123
(scale & permute) ? geev!(driver, A, Dd, V) : geevx!(driver, A, Dd, V; scale, permute)
122124
fixgauge && gaugefix!(eig_full!, V)
123-
return Dd, V
125+
return DV
124126
end
125-
function qr_iteration_eig_vals!(
126-
driver::Driver, A, D, V;
127+
function eig_vals_qr_iteration!(
128+
driver::Driver, A, D;
127129
fixgauge::Bool = default_fixgauge(), scale::Bool = true, permute::Bool = true
128130
)
131+
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
129132
(scale & permute) ? geev!(driver, A, D, V) : geevx!(driver, A, D, V; scale, permute)
130133
return D
131134
end
132135

133136
# Top-level QRIteration dispatch
134137
function eig_full!(A::AbstractMatrix, DV, alg::QRIteration)
135138
check_input(eig_full!, A, DV, alg)
136-
D, V = DV
137-
qr_iteration_eig_full!(A, diagview(D), V; alg.kwargs...)
138-
return D, V
139+
eig_full_qr_iteration!(A, DV; alg.kwargs...)
140+
return DV
139141
end
140142
function eig_vals!(A::AbstractMatrix, D, alg::QRIteration)
141143
check_input(eig_vals!, A, D, alg)
142-
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
143-
qr_iteration_eig_vals!(A, D, V; alg.kwargs...)
144+
eig_vals_qr_iteration!(A, D; alg.kwargs...)
144145
return D
145146
end
146147

src/implementations/eigh.jl

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -115,46 +115,47 @@ for (f, f_lapack!, Alg) in (
115115
(:bisection, :heevx!, :Bisection),
116116
(:jacobi, :heevj!, :Jacobi),
117117
)
118-
f_eigh_full! = Symbol(f, :_eigh_full!)
119-
f_eigh_vals! = Symbol(f, :_eigh_vals!)
118+
eigh_full_f! = Symbol(:eigh_full_, f, :!)
119+
eigh_vals_f! = Symbol(:eigh_vals_, f, :!)
120120

121121
# MatrixAlgebraKit wrappers
122122
@eval begin
123123
function eigh_full!(A::AbstractMatrix, DV, alg::$Alg)
124124
check_input(eigh_full!, A, DV, alg)
125-
D, V = DV
126-
Dd, V = $f_eigh_full!(A, D.diag, V; alg.kwargs...)
127-
return D, V
125+
$eigh_full_f!(A, DV; alg.kwargs...)
126+
return DV
128127
end
129128
function eigh_vals!(A::AbstractMatrix, D, alg::$Alg)
130129
check_input(eigh_vals!, A, D, alg)
131-
V = similar(A, (size(A, 1), 0))
132-
$f_eigh_vals!(A, D, V; alg.kwargs...)
130+
$eigh_vals_f!(A, D; alg.kwargs...)
133131
return D
134132
end
135133
end
136134

137135
# driver dispatch
138136
@eval begin
139-
@inline $f_eigh_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) =
140-
$f_eigh_full!(driver, A, Dd, V; kwargs...)
141-
@inline $f_eigh_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) =
142-
$f_eigh_vals!(driver, A, D, V; kwargs...)
137+
@inline $eigh_full_f!(A, DV; driver::Driver = DefaultDriver(), kwargs...) =
138+
$eigh_full_f!(driver, A, DV; kwargs...)
139+
@inline $eigh_vals_f!(A, D; driver::Driver = DefaultDriver(), kwargs...) =
140+
$eigh_vals_f!(driver, A, D; kwargs...)
143141

144-
@inline $f_eigh_full!(::DefaultDriver, A, Dd, V; kwargs...) =
145-
$f_eigh_full!(default_driver($Alg, A), A, Dd, V; kwargs...)
146-
@inline $f_eigh_vals!(::DefaultDriver, A, D, V; kwargs...) =
147-
$f_eigh_vals!(default_driver($Alg, A), A, D, V; kwargs...)
142+
@inline $eigh_full_f!(::DefaultDriver, A, DV; kwargs...) =
143+
$eigh_full_f!(default_driver($Alg, A), A, DV; kwargs...)
144+
@inline $eigh_vals_f!(::DefaultDriver, A, D; kwargs...) =
145+
$eigh_vals_f!(default_driver($Alg, A), A, D; kwargs...)
148146
end
149147

150148
# Implementation
151149
@eval begin
152-
function $f_eigh_full!(driver::Driver, A, Dd, V; fixgauge::Bool = default_fixgauge(), kwargs...)
150+
function $eigh_full_f!(driver::Driver, A, DV; fixgauge::Bool = default_fixgauge(), kwargs...)
151+
D, V = DV
152+
Dd = diagview(D)
153153
$f_lapack!(driver, A, Dd, V; kwargs...)
154154
fixgauge && gaugefix!(eigh_full!, V)
155-
return Dd, V
155+
return DV
156156
end
157-
function $f_eigh_vals!(driver::Driver, A, D, V; fixgauge::Bool = default_fixgauge(), kwargs...)
157+
function $eigh_vals_f!(driver::Driver, A, D; fixgauge::Bool = default_fixgauge(), kwargs...)
158+
V = similar(A, (size(A, 1), 0))
158159
$f_lapack!(driver, A, D, V; kwargs...)
159160
return D
160161
end

src/implementations/lq.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,15 @@ end
105105
# -----------
106106
function lq_full!(A, LQ, alg::Householder)
107107
check_input(lq_full!, A, LQ, alg)
108-
return householder_lq!(A, LQ...; alg.kwargs...)
108+
return lq_householder!(A, LQ...; alg.kwargs...)
109109
end
110110
function lq_compact!(A, LQ, alg::Householder)
111111
check_input(lq_compact!, A, LQ, alg)
112-
return householder_lq!(A, LQ...; alg.kwargs...)
112+
return lq_householder!(A, LQ...; alg.kwargs...)
113113
end
114114
function lq_null!(A, Nᴴ, alg::Householder)
115115
check_input(lq_null!, A, Nᴴ, alg)
116-
return householder_lq_null!(A, Nᴴ; alg.kwargs...)
116+
return lq_null_householder!(A, Nᴴ; alg.kwargs...)
117117
end
118118

119119
# dispatch helpers
@@ -123,13 +123,13 @@ for f in (:gelqt!, :gemlqt!, :gelqf!, :unglq!, :unmlq!)
123123
end
124124
end
125125

126-
@inline householder_lq!(A, L, Q; driver::Driver = DefaultDriver(), kwargs...) =
127-
householder_lq!(driver, A, L, Q; kwargs...)
128-
householder_lq!(::DefaultDriver, A, L, Q; kwargs...) =
129-
householder_lq!(default_driver(Householder, A), A, L, Q; kwargs...)
130-
householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) =
126+
@inline lq_householder!(A, L, Q; driver::Driver = DefaultDriver(), kwargs...) =
127+
lq_householder!(driver, A, L, Q; kwargs...)
128+
lq_householder!(::DefaultDriver, A, L, Q; kwargs...) =
129+
lq_householder!(default_driver(Householder, A), A, L, Q; kwargs...)
130+
lq_householder!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) =
131131
lq_via_qr!(A, L, Q, Householder(; driver, kwargs...))
132-
function householder_lq!(
132+
function lq_householder!(
133133
driver::LAPACK, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
134134
positive = true, pivoted = false, blocksize::Int = 0
135135
)
@@ -186,7 +186,7 @@ function householder_lq!(
186186
end
187187
return L, Q
188188
end
189-
function householder_lq!(
189+
function lq_householder!(
190190
driver::Native, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
191191
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
192192
)
@@ -229,13 +229,13 @@ function householder_lq!(
229229
return L, Q
230230
end
231231

232-
@inline householder_lq_null!(A, Nᴴ; driver::Driver = DefaultDriver(), kwargs...) =
233-
householder_lq_null!(driver, A, Nᴴ; kwargs...)
234-
householder_lq_null!(::DefaultDriver, A, Nᴴ; kwargs...) =
235-
householder_lq_null!(default_driver(Householder, A), A, Nᴴ; kwargs...)
236-
householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...) =
232+
@inline lq_null_householder!(A, Nᴴ; driver::Driver = DefaultDriver(), kwargs...) =
233+
lq_null_householder!(driver, A, Nᴴ; kwargs...)
234+
lq_null_householder!(::DefaultDriver, A, Nᴴ; kwargs...) =
235+
lq_null_householder!(default_driver(Householder, A), A, Nᴴ; kwargs...)
236+
lq_null_householder!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...) =
237237
lq_null_via_qr!(A, Nᴴ, Householder(; driver, kwargs...))
238-
function householder_lq_null!(
238+
function lq_null_householder!(
239239
driver::LAPACK, A::AbstractMatrix, Nᴴ::AbstractMatrix;
240240
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
241241
)
@@ -260,7 +260,7 @@ function householder_lq_null!(
260260
end
261261
return Nᴴ
262262
end
263-
function householder_lq_null!(
263+
function lq_null_householder!(
264264
driver::Native, A::AbstractMatrix, Nᴴ::AbstractMatrix;
265265
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
266266
)
@@ -343,21 +343,21 @@ end
343343
function lq_full!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
344344
check_input(lq_full!, A, LQ, alg)
345345
L, Q = LQ
346-
_diagonal_lq!(A, L, Q; alg.kwargs...)
346+
lq_diagonal!(A, L, Q; alg.kwargs...)
347347
return L, Q
348348
end
349349
function lq_compact!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
350350
check_input(lq_compact!, A, LQ, alg)
351351
L, Q = LQ
352-
_diagonal_lq!(A, L, Q; alg.kwargs...)
352+
lq_diagonal!(A, L, Q; alg.kwargs...)
353353
return L, Q
354354
end
355355
function lq_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm)
356356
check_input(lq_null!, A, N, alg)
357-
return _diagonal_lq_null!(A, N; alg.kwargs...)
357+
return lq_null_diagonal!(A, N; alg.kwargs...)
358358
end
359359

360-
function _diagonal_lq!(
360+
function lq_diagonal!(
361361
A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; positive::Bool = true
362362
)
363363
# note: Ad and Qd might share memory here so order of operations is important
@@ -374,7 +374,7 @@ function _diagonal_lq!(
374374
return L, Q
375375
end
376376

377-
_diagonal_lq_null!(A::AbstractMatrix, N; positive::Bool = true) = N
377+
lq_null_diagonal!(A::AbstractMatrix, N; positive::Bool = true) = N
378378

379379
# Deprecations
380380
# ------------

src/implementations/polar.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ function left_polar!(A::AbstractMatrix, WP, alg::PolarNewton)
100100
check_input(left_polar!, A, WP, alg)
101101
W, P = WP
102102
if isempty(P)
103-
W = _left_polarnewton!(A, W, P; alg.kwargs...)
103+
W = left_polar_newton!(A, W, P; alg.kwargs...)
104104
return W, P
105105
else
106-
W = _left_polarnewton!(copy(A), W, P; alg.kwargs...)
106+
W = left_polar_newton!(copy(A), W, P; alg.kwargs...)
107107
# we still need `A` to compute `P`
108108
P = project_hermitian!(mul!(P, W', A))
109109
return W, P
@@ -114,18 +114,18 @@ function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarNewton)
114114
check_input(right_polar!, A, PWᴴ, alg)
115115
P, Wᴴ = PWᴴ
116116
if isempty(P)
117-
Wᴴ = _right_polarnewton!(A, Wᴴ, P; alg.kwargs...)
117+
Wᴴ = right_polar_newton!(A, Wᴴ, P; alg.kwargs...)
118118
return P, Wᴴ
119119
else
120-
Wᴴ = _right_polarnewton!(copy(A), Wᴴ, P; alg.kwargs...)
120+
Wᴴ = right_polar_newton!(copy(A), Wᴴ, P; alg.kwargs...)
121121
# we still need `A` to compute `P`
122122
P = project_hermitian!(mul!(P, A, Wᴴ'))
123123
return P, Wᴴ
124124
end
125125
end
126126

127127
# these methods only compute W and destroy A in the process
128-
function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
128+
function left_polar_newton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
129129
m, n = size(A) # we must have m >= n
130130
Rᴴinv = isempty(P) ? similar(P, (n, n)) : P # use P as workspace when available
131131
if m > n # initial QR
@@ -165,7 +165,7 @@ function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol =
165165
return W
166166
end
167167

168-
function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
168+
function right_polar_newton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
169169
m, n = size(A) # we must have m <= n
170170
Lᴴinv = isempty(P) ? similar(P, (m, m)) : P # use P as workspace when available
171171
if m < n # initial QR

src/implementations/qr.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,15 @@ end
105105
# -----------
106106
function qr_full!(A, QR, alg::Householder)
107107
check_input(qr_full!, A, QR, alg)
108-
return householder_qr!(A, QR...; alg.kwargs...)
108+
return qr_householder!(A, QR...; alg.kwargs...)
109109
end
110110
function qr_compact!(A, QR, alg::Householder)
111111
check_input(qr_compact!, A, QR, alg)
112-
return householder_qr!(A, QR...; alg.kwargs...)
112+
return qr_householder!(A, QR...; alg.kwargs...)
113113
end
114114
function qr_null!(A, N, alg::Householder)
115115
check_input(qr_null!, A, N, alg)
116-
return householder_qr_null!(A, N; alg.kwargs...)
116+
return qr_null_householder!(A, N; alg.kwargs...)
117117
end
118118

119119

@@ -125,11 +125,11 @@ for f in (:geqrt!, :gemqrt!, :geqp3!, :geqrf!, :ungqr!, :unmqr!)
125125
end
126126
end
127127

128-
@inline householder_qr!(A, Q, R; driver::Driver = DefaultDriver(), kwargs...) =
129-
householder_qr!(driver, A, Q, R; kwargs...)
130-
householder_qr!(::DefaultDriver, A, Q, R; kwargs...) =
131-
householder_qr!(default_driver(Householder, A), A, Q, R; kwargs...)
132-
function householder_qr!(
128+
@inline qr_householder!(A, Q, R; driver::Driver = DefaultDriver(), kwargs...) =
129+
qr_householder!(driver, A, Q, R; kwargs...)
130+
qr_householder!(::DefaultDriver, A, Q, R; kwargs...) =
131+
qr_householder!(default_driver(Householder, A), A, Q, R; kwargs...)
132+
function qr_householder!(
133133
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
134134
positive::Bool = true, pivoted::Bool = false,
135135
blocksize::Int = 0
@@ -213,7 +213,7 @@ function householder_qr!(
213213
end
214214
return Q, R
215215
end
216-
function householder_qr!(
216+
function qr_householder!(
217217
driver::Native, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
218218
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
219219
)
@@ -256,11 +256,11 @@ function householder_qr!(
256256
return Q, R
257257
end
258258

259-
@inline householder_qr_null!(A, N; driver::Driver = DefaultDriver(), kwargs...) =
260-
householder_qr_null!(driver, A, N; kwargs...)
261-
householder_qr_null!(::DefaultDriver, A, N; kwargs...) =
262-
householder_qr_null!(default_driver(Householder, A), A, N; kwargs...)
263-
function householder_qr_null!(
259+
@inline qr_null_householder!(A, N; driver::Driver = DefaultDriver(), kwargs...) =
260+
qr_null_householder!(driver, A, N; kwargs...)
261+
qr_null_householder!(::DefaultDriver, A, N; kwargs...) =
262+
qr_null_householder!(default_driver(Householder, A), A, N; kwargs...)
263+
function qr_null_householder!(
264264
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix;
265265
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
266266
)
@@ -288,7 +288,7 @@ function householder_qr_null!(
288288
end
289289
return N
290290
end
291-
function householder_qr_null!(
291+
function qr_null_householder!(
292292
driver::Native, A::AbstractMatrix, N::AbstractMatrix;
293293
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
294294
)

0 commit comments

Comments
 (0)