11using MatrixAlgebraKit
22using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK. BlasMat, Algorithm
33import MatrixAlgebraKit as MAK
4- using TensorKit. Factorizations: @check_space , @check_scalar
54
65# Type piracy for defining the MAK rules on BlockArrays!
76# -----------------------------------------------------
@@ -16,6 +15,55 @@ function MatrixAlgebraKit.one!(A::BlockBlasMat)
1615 return A
1716end
1817
18+ for f in
19+ [
20+ :svd_compact , :svd_full , :svd_vals ,
21+ :qr_compact , :qr_full , :qr_null ,
22+ :lq_compact , :lq_full , :lq_null ,
23+ :eig_full , :eig_vals , :eigh_full , :eigh_vals ,
24+ :left_polar , :right_polar ,
25+ :project_hermitian , :project_antihermitian , :project_isometric ,
26+ ]
27+ f! = Symbol (f, :! )
28+ @eval MAK. default_algorithm (:: typeof ($ f!), :: Type{T} ; kwargs... ) where {T <: AbstractBlockTensorMap } =
29+ MAK. default_algorithm ($ f!, eltype (T); kwargs... )
30+ end
31+
32+ for f! in (
33+ :qr_compact! , :qr_full! , :lq_compact! , :lq_full! ,
34+ :eig_full! , :eigh_full! , :svd_compact! , :svd_full! ,
35+ :left_polar! , :right_polar! ,
36+ )
37+ @eval function MAK. $f! (t:: AbstractBlockTensorMap , F, alg:: AbstractAlgorithm )
38+ TensorKit. foreachblock (t, F... ) do _, (tblock, Fblocks... )
39+ Fblocks′ = MAK.$ f! (Array (tblock), alg)
40+ # deal with the case where the output is not in-place
41+ for (b′, b) in zip (Fblocks′, Fblocks)
42+ b === b′ || copy! (b, b′)
43+ end
44+ return nothing
45+ end
46+ return F
47+ end
48+ end
49+
50+ # Handle these separately because single output instead of tuple
51+ for f! in (
52+ :qr_null! , :lq_null! ,
53+ :svd_vals! , :eig_vals! , :eigh_vals! ,
54+ :project_hermitian! , :project_antihermitian! , :project_isometric! ,
55+ )
56+ @eval function MAK. $f! (t:: AbstractBlockTensorMap , N, alg:: AbstractAlgorithm )
57+ TensorKit. foreachblock (t, N) do _, (tblock, Nblock)
58+ Nblock′ = MAK.$ f! (Array (tblock), alg)
59+ # deal with the case where the output is not the same as the input
60+ Nblock === Nblock′ || copy! (Nblock, Nblock′)
61+ return nothing
62+ end
63+ return N
64+ end
65+ end
66+
1967for f in (
2068 :svd_compact , :svd_full , :svd_vals , :qr_compact , :qr_full , :qr_null ,
2169 :lq_compact , :lq_full , :lq_null , :eig_full , :eig_vals , :eigh_full ,
@@ -26,60 +74,58 @@ for f in (
2674 @eval MAK.$ f! (t:: BlockBlasMat , F, alg:: MAK.DiagonalAlgorithm ) = error (" Not diagonal" )
2775end
2876
29- # disambiguations
30- for (f!, Alg) in (
31- (:lq_compact! , :LAPACK_HouseholderLQ ), (:lq_full! , :LAPACK_HouseholderLQ ), (:lq_null! , :LAPACK_HouseholderLQ ),
32- (:lq_compact! , :LQViaTransposedQR ), (:lq_full! , :LQViaTransposedQR ), (:lq_null! , :LQViaTransposedQR ),
33- (:qr_compact! , :LAPACK_HouseholderQR ), (:qr_full! , :LAPACK_HouseholderQR ), (:qr_null! , :LAPACK_HouseholderQR ),
34- (:svd_compact! , :LAPACK_SVDAlgorithm ), (:svd_full! , :LAPACK_SVDAlgorithm ), (:svd_vals! , :LAPACK_SVDAlgorithm ),
35- (:eig_full! , :LAPACK_EigAlgorithm ), (:eig_trunc! , :TruncatedAlgorithm ), (:eig_vals! , :LAPACK_EigAlgorithm ),
36- (:eigh_full! , :LAPACK_EighAlgorithm ), (:eigh_trunc! , :TruncatedAlgorithm ), (:eigh_vals! , :LAPACK_EighAlgorithm ),
37- (:left_polar! , :PolarViaSVD ), (:right_polar! , :PolarViaSVD ),
38- )
39- @eval MAK.$ f! (t:: BlockBlasMat , F, alg:: MAK. $ Alg) = $ f! (Array (t), alg)
40- end
41-
42- const GPU_QRAlgorithm = Union{MAK. CUSOLVER_HouseholderQR, MAK. ROCSOLVER_HouseholderQR}
43- for f! in (:qr_compact! , :qr_full! , :qr_null! )
44- @eval MAK.$ f! (t:: BlockBlasMat , QR, alg:: GPU_QRAlgorithm ) = error ()
77+ # specializations until fixes in base package
78+ function MAK. is_left_isometric (A:: BlockMatrix ; atol:: Real = 0 , rtol:: Real = MAK. defaulttol (A), norm = LinearAlgebra. norm)
79+ P = A' * A
80+ nP = norm (P) # isapprox would use `rtol * max(norm(P), norm(I))`
81+ for I in MAK. diagind (P)
82+ P[I] -= 1
83+ end
84+ return norm (P) <= max (atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
4585end
46-
47- for (f!, Alg) in (
48- ( :eigh_full! , :GPU_EighAlgorithm ), ( :eigh_vals! , :GPU_EighAlgorithm ),
49- ( :eig_full! , :GPU_EigAlgorithm ), ( :eig_vals! , :GPU_EigAlgorithm ),
50- ( :svd_full! , :GPU_SVDAlgorithm ), ( :svd_compact! , :GPU_SVDAlgorithm ), ( :svd_vals! , :GPU_SVDAlgorithm ),
51- )
52- @eval MAK. $ f! (t :: BlockBlasMat , F, alg :: MAK. $ Alg) = error ()
86+ function MAK . is_right_isometric (A :: BlockMatrix ; atol :: Real = 0 , rtol :: Real = MAK . defaulttol (A), norm = LinearAlgebra . norm)
87+ P = A * A '
88+ nP = norm (P) # isapprox would use `rtol * max(norm(P), norm(I))`
89+ for I in MAK . diagind (P)
90+ P[I] -= 1
91+ end
92+ return norm (P) <= max (atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
5393end
5494
55-
56- for f in (:qr , :lq , :eig , :eigh , :gen_eig , :svd , :polar )
57- default_f_algorithm = Symbol (:default_ , f, :_algorithm )
58- @eval MAK.$ default_f_algorithm (:: Type{<:BlockBlasMat{T}} ; kwargs... ) where {T} =
59- MAK.$ default_f_algorithm (Matrix{T}; kwargs... )
60- end
95+ # disambiguations
96+ # for (f!, Alg) in (
97+ # (:lq_compact!, :LAPACK_HouseholderLQ), (:lq_full!, :LAPACK_HouseholderLQ), (:lq_null!, :LAPACK_HouseholderLQ),
98+ # (:lq_compact!, :LQViaTransposedQR), (:lq_full!, :LQViaTransposedQR), (:lq_null!, :LQViaTransposedQR),
99+ # (:qr_compact!, :LAPACK_HouseholderQR), (:qr_full!, :LAPACK_HouseholderQR), (:qr_null!, :LAPACK_HouseholderQR),
100+ # (:svd_compact!, :LAPACK_SVDAlgorithm), (:svd_full!, :LAPACK_SVDAlgorithm), (:svd_vals!, :LAPACK_SVDAlgorithm),
101+ # (:eig_full!, :LAPACK_EigAlgorithm), (:eig_trunc!, :TruncatedAlgorithm), (:eig_vals!, :LAPACK_EigAlgorithm),
102+ # (:eigh_full!, :LAPACK_EighAlgorithm), (:eigh_trunc!, :TruncatedAlgorithm), (:eigh_vals!, :LAPACK_EighAlgorithm),
103+ # (:left_polar!, :PolarViaSVD), (:right_polar!, :PolarViaSVD),
104+ # )
105+ # @eval MAK.$f!(t::BlockBlasMat, F, alg::MAK.$Alg) = $f!(Array(t), alg)
106+ # end
107+ #
108+ # const GPU_QRAlgorithm = Union{MAK.CUSOLVER_HouseholderQR, MAK.ROCSOLVER_HouseholderQR}
109+ # for f! in (:qr_compact!, :qr_full!, :qr_null!)
110+ # @eval MAK.$f!(t::BlockBlasMat, QR, alg::GPU_QRAlgorithm) = error()
111+ # end
112+ #
113+ # for (f!, Alg) in (
114+ # (:eigh_full!, :GPU_EighAlgorithm), (:eigh_vals!, :GPU_EighAlgorithm),
115+ # (:eig_full!, :GPU_EigAlgorithm), (:eig_vals!, :GPU_EigAlgorithm),
116+ # (:svd_full!, :GPU_SVDAlgorithm), (:svd_compact!, :GPU_SVDAlgorithm), (:svd_vals!, :GPU_SVDAlgorithm),
117+ # )
118+ # @eval MAK.$f!(t::BlockBlasMat, F, alg::MAK.$Alg) = error()
119+ # end
120+
121+
122+ # for f in (:qr, :lq, :eig, :eigh, :gen_eig, :svd, :polar)
123+ # default_f_algorithm = Symbol(:default_, f, :_algorithm)
124+ # @eval MAK.$default_f_algorithm(::Type{<:BlockBlasMat{T}}; kwargs...) where {T} =
125+ # MAK.$default_f_algorithm(Matrix{T}; kwargs...)
126+ # end
61127
62128# Make sure sparse blocktensormaps have dense outputs
63- function MAK. check_input (:: typeof (qr_full!), t:: AbstractBlockTensorMap , QR, :: AbstractAlgorithm )
64- Q, R = QR
65-
66- # type checks
67- @assert Q isa AbstractTensorMap
68- @assert R isa AbstractTensorMap
69-
70- # scalartype checks
71- @check_scalar Q t
72- @check_scalar R t
73-
74- # space checks
75- V_Q = ⊕ (fuse (codomain (t)))
76- @check_space (Q, codomain (t) ← V_Q)
77- @check_space (R, V_Q ← domain (t))
78-
79- return nothing
80- end
81- MAK. check_input (:: typeof (qr_full!), t:: AbstractBlockTensorMap , QR, :: DiagonalAlgorithm ) = error ()
82-
83129function MAK. initialize_output (:: typeof (qr_full!), t:: AbstractBlockTensorMap , :: AbstractAlgorithm )
84130 V_Q = ⊕ (fuse (codomain (t)))
85131 Q = dense_similar (t, codomain (t) ← V_Q)
@@ -99,26 +145,6 @@ function MAK.initialize_output(::typeof(qr_null!), t::AbstractBlockTensorMap, ::
99145 return N
100146end
101147
102- function MAK. check_input (:: typeof (lq_full!), t:: AbstractBlockTensorMap , LQ, :: AbstractAlgorithm )
103- L, Q = LQ
104-
105- # type checks
106- @assert L isa AbstractTensorMap
107- @assert Q isa AbstractTensorMap
108-
109- # scalartype checks
110- @check_scalar L t
111- @check_scalar Q t
112-
113- # space checks
114- V_Q = ⊕ (fuse (domain (t)))
115- @check_space (L, codomain (t) ← V_Q)
116- @check_space (Q, V_Q ← domain (t))
117-
118- return nothing
119- end
120- MAK. check_input (:: typeof (lq_full!), t:: AbstractBlockTensorMap , LQ, :: DiagonalAlgorithm ) = error ()
121-
122148function MAK. initialize_output (:: typeof (lq_full!), t:: AbstractBlockTensorMap , :: AbstractAlgorithm )
123149 V_Q = ⊕ (fuse (domain (t)))
124150 L = dense_similar (t, codomain (t) ← V_Q)
@@ -138,99 +164,18 @@ function MAK.initialize_output(::typeof(lq_null!), t::AbstractBlockTensorMap, ::
138164 return N
139165end
140166
141- function MAK. check_input (:: typeof (MAK. left_orth_polar!), t:: AbstractBlockTensorMap , WP, :: AbstractAlgorithm )
142- codomain (t) ≿ domain (t) ||
143- throw (ArgumentError (" Polar decomposition requires `codomain(t) ≿ domain(t)`" ))
144-
145- W, P = WP
146- @assert W isa AbstractTensorMap
147- @assert P isa AbstractTensorMap
148-
149- # scalartype checks
150- @check_scalar W t
151- @check_scalar P t
152-
153- # space checks
154- VW = ⊕ (fuse (domain (t)))
155- @check_space (W, codomain (t) ← VW)
156- @check_space (P, VW ← domain (t))
157-
158- return nothing
159- end
160167function MAK. initialize_output (:: typeof (left_polar!), t:: AbstractBlockTensorMap , :: AbstractAlgorithm )
161168 W = dense_similar (t, space (t))
162169 P = dense_similar (t, domain (t) ← domain (t))
163170 return W, P
164171end
165172
166- function MAK. check_input (:: typeof (MAK. right_orth_polar!), t:: AbstractBlockTensorMap , PWᴴ, :: AbstractAlgorithm )
167- codomain (t) ≾ domain (t) ||
168- throw (ArgumentError (" Polar decomposition requires `domain(t) ≿ codomain(t)`" ))
169-
170- P, Wᴴ = PWᴴ
171- @assert P isa AbstractTensorMap
172- @assert Wᴴ isa AbstractTensorMap
173-
174- # scalartype checks
175- @check_scalar P t
176- @check_scalar Wᴴ t
177-
178- # space checks
179- VW = ⊕ (fuse (codomain (t)))
180- @check_space (P, codomain (t) ← VW)
181- @check_space (Wᴴ, VW ← domain (t))
182-
183- return nothing
184- end
185173function MAK. initialize_output (:: typeof (right_polar!), t:: AbstractBlockTensorMap , :: AbstractAlgorithm )
186174 P = dense_similar (t, codomain (t) ← codomain (t))
187175 Wᴴ = dense_similar (t, space (t))
188176 return P, Wᴴ
189177end
190178
191- function MAK. initialize_output (:: typeof (left_null!), t:: AbstractBlockTensorMap , :: AbstractAlgorithm )
192- V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
193- V_N = ⊖ (fuse (codomain (t)), V_Q)
194- return dense_similar (t, codomain (t) ← V_N)
195- end
196- function MAK. initialize_output (:: typeof (right_null!), t:: AbstractBlockTensorMap , :: AbstractAlgorithm )
197- V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
198- V_N = ⊖ (fuse (domain (t)), V_Q)
199- return dense_similar (t, V_N ← domain (t))
200- end
201-
202- function MAK. check_input (:: typeof (eigh_full!), t:: AbstractBlockTensorMap , DV, :: AbstractAlgorithm )
203- domain (t) == codomain (t) ||
204- throw (ArgumentError (" Eigenvalue decomposition requires square input tensor" ))
205-
206- D, V = DV
207-
208- # type checks
209- @assert D isa DiagonalTensorMap
210- @assert V isa AbstractTensorMap
211-
212- # scalartype checks
213- @check_scalar D t real
214- @check_scalar V t
215-
216- # space checks
217- V_D = ⊕ (fuse (domain (t)))
218- @check_space (D, V_D ← V_D)
219- @check_space (V, codomain (t) ← V_D)
220-
221- return nothing
222- end
223- MAK. check_input (:: typeof (eigh_full!), t:: AbstractBlockTensorMap , DV, :: DiagonalAlgorithm ) = error ()
224-
225- function MAK. check_input (:: typeof (eigh_vals!), t:: AbstractBlockTensorMap , D, :: AbstractAlgorithm )
226- @check_scalar D t real
227- @assert D isa DiagonalTensorMap
228- V_D = ⊕ (fuse (domain (t)))
229- @check_space (D, V_D ← V_D)
230- return nothing
231- end
232- MAK. check_input (:: typeof (eigh_vals!), t:: AbstractBlockTensorMap , D, :: DiagonalAlgorithm ) = error ()
233-
234179function MAK. initialize_output (:: typeof (eigh_full!), t:: AbstractBlockTensorMap , :: AbstractAlgorithm )
235180 V_D = ⊕ (fuse (domain (t)))
236181 T = real (scalartype (t))
@@ -239,30 +184,6 @@ function MAK.initialize_output(::typeof(eigh_full!), t::AbstractBlockTensorMap,
239184 return D, V
240185end
241186
242-
243- function MAK. check_input (:: typeof (eig_full!), t:: AbstractBlockTensorMap , DV, :: AbstractAlgorithm )
244- domain (t) == codomain (t) ||
245- throw (ArgumentError (" Eigenvalue decomposition requires square input tensor" ))
246-
247- D, V = DV
248-
249- # type checks
250- @assert D isa DiagonalTensorMap
251- @assert V isa AbstractTensorMap
252-
253- # scalartype checks
254- @check_scalar D t complex
255- @check_scalar V t complex
256-
257- # space checks
258- V_D = ⊕ (fuse (domain (t)))
259- @check_space (D, V_D ← V_D)
260- @check_space (V, codomain (t) ← V_D)
261-
262- return nothing
263- end
264- MAK. check_input (:: typeof (eig_full!), t:: AbstractBlockTensorMap , DV, :: DiagonalAlgorithm ) = error ()
265-
266187function MAK. initialize_output (:: typeof (eig_full!), t:: AbstractBlockTensorMap , :: AbstractAlgorithm )
267188 V_D = ⊕ (fuse (domain (t)))
268189 Tc = complex (scalartype (t))
@@ -271,28 +192,6 @@ function MAK.initialize_output(::typeof(eig_full!), t::AbstractBlockTensorMap, :
271192 return D, V
272193end
273194
274- function MAK. check_input (:: typeof (svd_full!), t:: AbstractBlockTensorMap , USVᴴ, :: AbstractAlgorithm )
275- U, S, Vᴴ = USVᴴ
276-
277- # type checks
278- @assert U isa AbstractTensorMap
279- @assert S isa AbstractTensorMap
280- @assert Vᴴ isa AbstractTensorMap
281-
282- # scalartype checks
283- @check_scalar U t
284- @check_scalar S t real
285- @check_scalar Vᴴ t
286-
287- # space checks
288- V_cod = ⊕ (fuse (codomain (t)))
289- V_dom = ⊕ (fuse (domain (t)))
290- @check_space (U, codomain (t) ← V_cod)
291- @check_space (S, V_cod ← V_dom)
292- @check_space (Vᴴ, V_dom ← domain (t))
293-
294- return nothing
295- end
296195function MAK. initialize_output (:: typeof (svd_full!), t:: AbstractBlockTensorMap , :: AbstractAlgorithm )
297196 V_cod = ⊕ (fuse (codomain (t)))
298197 V_dom = ⊕ (fuse (domain (t)))
0 commit comments