Skip to content

Commit a099ef4

Browse files
committed
rework factorizations to be MAK v0.6 compatible
1 parent caee937 commit a099ef4

2 files changed

Lines changed: 139 additions & 240 deletions

File tree

src/linalg/factorizations.jl

Lines changed: 96 additions & 197 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using MatrixAlgebraKit
22
using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm
33
import MatrixAlgebraKit as MAK
4-
using TensorKit.Factorizations: @check_space, @check_scalar
54

65
# Type piracy for defining the MAK rules on BlockArrays!
76
# -----------------------------------------------------
@@ -16,6 +15,55 @@ function MatrixAlgebraKit.one!(A::BlockBlasMat)
1615
return A
1716
end
1817

18+
for f in
19+
[
20+
:svd_compact, :svd_full, :svd_vals,
21+
:qr_compact, :qr_full, :qr_null,
22+
:lq_compact, :lq_full, :lq_null,
23+
:eig_full, :eig_vals, :eigh_full, :eigh_vals,
24+
:left_polar, :right_polar,
25+
:project_hermitian, :project_antihermitian, :project_isometric,
26+
]
27+
f! = Symbol(f, :!)
28+
@eval MAK.default_algorithm(::typeof($f!), ::Type{T}; kwargs...) where {T <: AbstractBlockTensorMap} =
29+
MAK.default_algorithm($f!, eltype(T); kwargs...)
30+
end
31+
32+
for f! in (
33+
:qr_compact!, :qr_full!, :lq_compact!, :lq_full!,
34+
:eig_full!, :eigh_full!, :svd_compact!, :svd_full!,
35+
:left_polar!, :right_polar!,
36+
)
37+
@eval function MAK.$f!(t::AbstractBlockTensorMap, F, alg::AbstractAlgorithm)
38+
TensorKit.foreachblock(t, F...) do _, (tblock, Fblocks...)
39+
Fblocks′ = MAK.$f!(Array(tblock), alg)
40+
# deal with the case where the output is not in-place
41+
for (b′, b) in zip(Fblocks′, Fblocks)
42+
b === b′ || copy!(b, b′)
43+
end
44+
return nothing
45+
end
46+
return F
47+
end
48+
end
49+
50+
# Handle these separately because single output instead of tuple
51+
for f! in (
52+
:qr_null!, :lq_null!,
53+
:svd_vals!, :eig_vals!, :eigh_vals!,
54+
:project_hermitian!, :project_antihermitian!, :project_isometric!,
55+
)
56+
@eval function MAK.$f!(t::AbstractBlockTensorMap, N, alg::AbstractAlgorithm)
57+
TensorKit.foreachblock(t, N) do _, (tblock, Nblock)
58+
Nblock′ = MAK.$f!(Array(tblock), alg)
59+
# deal with the case where the output is not the same as the input
60+
Nblock === Nblock′ || copy!(Nblock, Nblock′)
61+
return nothing
62+
end
63+
return N
64+
end
65+
end
66+
1967
for f in (
2068
:svd_compact, :svd_full, :svd_vals, :qr_compact, :qr_full, :qr_null,
2169
:lq_compact, :lq_full, :lq_null, :eig_full, :eig_vals, :eigh_full,
@@ -26,60 +74,58 @@ for f in (
2674
@eval MAK.$f!(t::BlockBlasMat, F, alg::MAK.DiagonalAlgorithm) = error("Not diagonal")
2775
end
2876

29-
# disambiguations
30-
for (f!, Alg) in (
31-
(:lq_compact!, :LAPACK_HouseholderLQ), (:lq_full!, :LAPACK_HouseholderLQ), (:lq_null!, :LAPACK_HouseholderLQ),
32-
(:lq_compact!, :LQViaTransposedQR), (:lq_full!, :LQViaTransposedQR), (:lq_null!, :LQViaTransposedQR),
33-
(:qr_compact!, :LAPACK_HouseholderQR), (:qr_full!, :LAPACK_HouseholderQR), (:qr_null!, :LAPACK_HouseholderQR),
34-
(:svd_compact!, :LAPACK_SVDAlgorithm), (:svd_full!, :LAPACK_SVDAlgorithm), (:svd_vals!, :LAPACK_SVDAlgorithm),
35-
(:eig_full!, :LAPACK_EigAlgorithm), (:eig_trunc!, :TruncatedAlgorithm), (:eig_vals!, :LAPACK_EigAlgorithm),
36-
(:eigh_full!, :LAPACK_EighAlgorithm), (:eigh_trunc!, :TruncatedAlgorithm), (:eigh_vals!, :LAPACK_EighAlgorithm),
37-
(:left_polar!, :PolarViaSVD), (:right_polar!, :PolarViaSVD),
38-
)
39-
@eval MAK.$f!(t::BlockBlasMat, F, alg::MAK.$Alg) = $f!(Array(t), alg)
40-
end
41-
42-
const GPU_QRAlgorithm = Union{MAK.CUSOLVER_HouseholderQR, MAK.ROCSOLVER_HouseholderQR}
43-
for f! in (:qr_compact!, :qr_full!, :qr_null!)
44-
@eval MAK.$f!(t::BlockBlasMat, QR, alg::GPU_QRAlgorithm) = error()
77+
# specializations until fixes in base package
78+
function MAK.is_left_isometric(A::BlockMatrix; atol::Real = 0, rtol::Real = MAK.defaulttol(A), norm = LinearAlgebra.norm)
79+
P = A' * A
80+
nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))`
81+
for I in MAK.diagind(P)
82+
P[I] -= 1
83+
end
84+
return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
4585
end
46-
47-
for (f!, Alg) in (
48-
(:eigh_full!, :GPU_EighAlgorithm), (:eigh_vals!, :GPU_EighAlgorithm),
49-
(:eig_full!, :GPU_EigAlgorithm), (:eig_vals!, :GPU_EigAlgorithm),
50-
(:svd_full!, :GPU_SVDAlgorithm), (:svd_compact!, :GPU_SVDAlgorithm), (:svd_vals!, :GPU_SVDAlgorithm),
51-
)
52-
@eval MAK.$f!(t::BlockBlasMat, F, alg::MAK.$Alg) = error()
86+
function MAK.is_right_isometric(A::BlockMatrix; atol::Real = 0, rtol::Real = MAK.defaulttol(A), norm = LinearAlgebra.norm)
87+
P = A * A'
88+
nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))`
89+
for I in MAK.diagind(P)
90+
P[I] -= 1
91+
end
92+
return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
5393
end
5494

55-
56-
for f in (:qr, :lq, :eig, :eigh, :gen_eig, :svd, :polar)
57-
default_f_algorithm = Symbol(:default_, f, :_algorithm)
58-
@eval MAK.$default_f_algorithm(::Type{<:BlockBlasMat{T}}; kwargs...) where {T} =
59-
MAK.$default_f_algorithm(Matrix{T}; kwargs...)
60-
end
95+
# disambiguations
96+
# for (f!, Alg) in (
97+
# (:lq_compact!, :LAPACK_HouseholderLQ), (:lq_full!, :LAPACK_HouseholderLQ), (:lq_null!, :LAPACK_HouseholderLQ),
98+
# (:lq_compact!, :LQViaTransposedQR), (:lq_full!, :LQViaTransposedQR), (:lq_null!, :LQViaTransposedQR),
99+
# (:qr_compact!, :LAPACK_HouseholderQR), (:qr_full!, :LAPACK_HouseholderQR), (:qr_null!, :LAPACK_HouseholderQR),
100+
# (:svd_compact!, :LAPACK_SVDAlgorithm), (:svd_full!, :LAPACK_SVDAlgorithm), (:svd_vals!, :LAPACK_SVDAlgorithm),
101+
# (:eig_full!, :LAPACK_EigAlgorithm), (:eig_trunc!, :TruncatedAlgorithm), (:eig_vals!, :LAPACK_EigAlgorithm),
102+
# (:eigh_full!, :LAPACK_EighAlgorithm), (:eigh_trunc!, :TruncatedAlgorithm), (:eigh_vals!, :LAPACK_EighAlgorithm),
103+
# (:left_polar!, :PolarViaSVD), (:right_polar!, :PolarViaSVD),
104+
# )
105+
# @eval MAK.$f!(t::BlockBlasMat, F, alg::MAK.$Alg) = $f!(Array(t), alg)
106+
# end
107+
#
108+
# const GPU_QRAlgorithm = Union{MAK.CUSOLVER_HouseholderQR, MAK.ROCSOLVER_HouseholderQR}
109+
# for f! in (:qr_compact!, :qr_full!, :qr_null!)
110+
# @eval MAK.$f!(t::BlockBlasMat, QR, alg::GPU_QRAlgorithm) = error()
111+
# end
112+
#
113+
# for (f!, Alg) in (
114+
# (:eigh_full!, :GPU_EighAlgorithm), (:eigh_vals!, :GPU_EighAlgorithm),
115+
# (:eig_full!, :GPU_EigAlgorithm), (:eig_vals!, :GPU_EigAlgorithm),
116+
# (:svd_full!, :GPU_SVDAlgorithm), (:svd_compact!, :GPU_SVDAlgorithm), (:svd_vals!, :GPU_SVDAlgorithm),
117+
# )
118+
# @eval MAK.$f!(t::BlockBlasMat, F, alg::MAK.$Alg) = error()
119+
# end
120+
121+
122+
# for f in (:qr, :lq, :eig, :eigh, :gen_eig, :svd, :polar)
123+
# default_f_algorithm = Symbol(:default_, f, :_algorithm)
124+
# @eval MAK.$default_f_algorithm(::Type{<:BlockBlasMat{T}}; kwargs...) where {T} =
125+
# MAK.$default_f_algorithm(Matrix{T}; kwargs...)
126+
# end
61127

62128
# Make sure sparse blocktensormaps have dense outputs
63-
function MAK.check_input(::typeof(qr_full!), t::AbstractBlockTensorMap, QR, ::AbstractAlgorithm)
64-
Q, R = QR
65-
66-
# type checks
67-
@assert Q isa AbstractTensorMap
68-
@assert R isa AbstractTensorMap
69-
70-
# scalartype checks
71-
@check_scalar Q t
72-
@check_scalar R t
73-
74-
# space checks
75-
V_Q = (fuse(codomain(t)))
76-
@check_space(Q, codomain(t) V_Q)
77-
@check_space(R, V_Q domain(t))
78-
79-
return nothing
80-
end
81-
MAK.check_input(::typeof(qr_full!), t::AbstractBlockTensorMap, QR, ::DiagonalAlgorithm) = error()
82-
83129
function MAK.initialize_output(::typeof(qr_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
84130
V_Q = (fuse(codomain(t)))
85131
Q = dense_similar(t, codomain(t) V_Q)
@@ -99,26 +145,6 @@ function MAK.initialize_output(::typeof(qr_null!), t::AbstractBlockTensorMap, ::
99145
return N
100146
end
101147

102-
function MAK.check_input(::typeof(lq_full!), t::AbstractBlockTensorMap, LQ, ::AbstractAlgorithm)
103-
L, Q = LQ
104-
105-
# type checks
106-
@assert L isa AbstractTensorMap
107-
@assert Q isa AbstractTensorMap
108-
109-
# scalartype checks
110-
@check_scalar L t
111-
@check_scalar Q t
112-
113-
# space checks
114-
V_Q = (fuse(domain(t)))
115-
@check_space(L, codomain(t) V_Q)
116-
@check_space(Q, V_Q domain(t))
117-
118-
return nothing
119-
end
120-
MAK.check_input(::typeof(lq_full!), t::AbstractBlockTensorMap, LQ, ::DiagonalAlgorithm) = error()
121-
122148
function MAK.initialize_output(::typeof(lq_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
123149
V_Q = (fuse(domain(t)))
124150
L = dense_similar(t, codomain(t) V_Q)
@@ -138,99 +164,18 @@ function MAK.initialize_output(::typeof(lq_null!), t::AbstractBlockTensorMap, ::
138164
return N
139165
end
140166

141-
function MAK.check_input(::typeof(MAK.left_orth_polar!), t::AbstractBlockTensorMap, WP, ::AbstractAlgorithm)
142-
codomain(t) domain(t) ||
143-
throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`"))
144-
145-
W, P = WP
146-
@assert W isa AbstractTensorMap
147-
@assert P isa AbstractTensorMap
148-
149-
# scalartype checks
150-
@check_scalar W t
151-
@check_scalar P t
152-
153-
# space checks
154-
VW = (fuse(domain(t)))
155-
@check_space(W, codomain(t) VW)
156-
@check_space(P, VW domain(t))
157-
158-
return nothing
159-
end
160167
function MAK.initialize_output(::typeof(left_polar!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
161168
W = dense_similar(t, space(t))
162169
P = dense_similar(t, domain(t) domain(t))
163170
return W, P
164171
end
165172

166-
function MAK.check_input(::typeof(MAK.right_orth_polar!), t::AbstractBlockTensorMap, PWᴴ, ::AbstractAlgorithm)
167-
codomain(t) domain(t) ||
168-
throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`"))
169-
170-
P, Wᴴ = PWᴴ
171-
@assert P isa AbstractTensorMap
172-
@assert Wᴴ isa AbstractTensorMap
173-
174-
# scalartype checks
175-
@check_scalar P t
176-
@check_scalar Wᴴ t
177-
178-
# space checks
179-
VW = (fuse(codomain(t)))
180-
@check_space(P, codomain(t) VW)
181-
@check_space(Wᴴ, VW domain(t))
182-
183-
return nothing
184-
end
185173
function MAK.initialize_output(::typeof(right_polar!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
186174
P = dense_similar(t, codomain(t) codomain(t))
187175
Wᴴ = dense_similar(t, space(t))
188176
return P, Wᴴ
189177
end
190178

191-
function MAK.initialize_output(::typeof(left_null!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
192-
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
193-
V_N = (fuse(codomain(t)), V_Q)
194-
return dense_similar(t, codomain(t) V_N)
195-
end
196-
function MAK.initialize_output(::typeof(right_null!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
197-
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
198-
V_N = (fuse(domain(t)), V_Q)
199-
return dense_similar(t, V_N domain(t))
200-
end
201-
202-
function MAK.check_input(::typeof(eigh_full!), t::AbstractBlockTensorMap, DV, ::AbstractAlgorithm)
203-
domain(t) == codomain(t) ||
204-
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
205-
206-
D, V = DV
207-
208-
# type checks
209-
@assert D isa DiagonalTensorMap
210-
@assert V isa AbstractTensorMap
211-
212-
# scalartype checks
213-
@check_scalar D t real
214-
@check_scalar V t
215-
216-
# space checks
217-
V_D = (fuse(domain(t)))
218-
@check_space(D, V_D V_D)
219-
@check_space(V, codomain(t) V_D)
220-
221-
return nothing
222-
end
223-
MAK.check_input(::typeof(eigh_full!), t::AbstractBlockTensorMap, DV, ::DiagonalAlgorithm) = error()
224-
225-
function MAK.check_input(::typeof(eigh_vals!), t::AbstractBlockTensorMap, D, ::AbstractAlgorithm)
226-
@check_scalar D t real
227-
@assert D isa DiagonalTensorMap
228-
V_D = (fuse(domain(t)))
229-
@check_space(D, V_D V_D)
230-
return nothing
231-
end
232-
MAK.check_input(::typeof(eigh_vals!), t::AbstractBlockTensorMap, D, ::DiagonalAlgorithm) = error()
233-
234179
function MAK.initialize_output(::typeof(eigh_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
235180
V_D = (fuse(domain(t)))
236181
T = real(scalartype(t))
@@ -239,30 +184,6 @@ function MAK.initialize_output(::typeof(eigh_full!), t::AbstractBlockTensorMap,
239184
return D, V
240185
end
241186

242-
243-
function MAK.check_input(::typeof(eig_full!), t::AbstractBlockTensorMap, DV, ::AbstractAlgorithm)
244-
domain(t) == codomain(t) ||
245-
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
246-
247-
D, V = DV
248-
249-
# type checks
250-
@assert D isa DiagonalTensorMap
251-
@assert V isa AbstractTensorMap
252-
253-
# scalartype checks
254-
@check_scalar D t complex
255-
@check_scalar V t complex
256-
257-
# space checks
258-
V_D = (fuse(domain(t)))
259-
@check_space(D, V_D V_D)
260-
@check_space(V, codomain(t) V_D)
261-
262-
return nothing
263-
end
264-
MAK.check_input(::typeof(eig_full!), t::AbstractBlockTensorMap, DV, ::DiagonalAlgorithm) = error()
265-
266187
function MAK.initialize_output(::typeof(eig_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
267188
V_D = (fuse(domain(t)))
268189
Tc = complex(scalartype(t))
@@ -271,28 +192,6 @@ function MAK.initialize_output(::typeof(eig_full!), t::AbstractBlockTensorMap, :
271192
return D, V
272193
end
273194

274-
function MAK.check_input(::typeof(svd_full!), t::AbstractBlockTensorMap, USVᴴ, ::AbstractAlgorithm)
275-
U, S, Vᴴ = USVᴴ
276-
277-
# type checks
278-
@assert U isa AbstractTensorMap
279-
@assert S isa AbstractTensorMap
280-
@assert Vᴴ isa AbstractTensorMap
281-
282-
# scalartype checks
283-
@check_scalar U t
284-
@check_scalar S t real
285-
@check_scalar Vᴴ t
286-
287-
# space checks
288-
V_cod = (fuse(codomain(t)))
289-
V_dom = (fuse(domain(t)))
290-
@check_space(U, codomain(t) V_cod)
291-
@check_space(S, V_cod V_dom)
292-
@check_space(Vᴴ, V_dom domain(t))
293-
294-
return nothing
295-
end
296195
function MAK.initialize_output(::typeof(svd_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
297196
V_cod = (fuse(codomain(t)))
298197
V_dom = (fuse(domain(t)))

0 commit comments

Comments
 (0)