Skip to content

Commit 96cdad8

Browse files
authored
Implement DefaultAlgorithm support (#422)
* add `DefaultAlgorithm` overloads * add DefaultAlgorithm tests * disambiguate
1 parent e2b0655 commit 96cdad8

6 files changed

Lines changed: 344 additions & 5 deletions

File tree

src/factorizations/matrixalgebrakit.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,38 @@ end
6464

6565
MAK.zero!(t::AbstractTensorMap) = zerovector!(t)
6666

67+
# Default algorithm
68+
# -----------------
69+
for f in [
70+
:lq_full, :lq_compact, :lq_null,
71+
:qr_full, :qr_compact, :qr_null,
72+
:schur_full, :schur_vals,
73+
:eig_full, :eig_vals, :eig_trunc, :eig_trunc_no_error,
74+
:eigh_full, :eigh_vals, :eigh_trunc, :eigh_trunc_no_error,
75+
:svd_full, :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals,
76+
:left_polar, :right_polar,
77+
:left_orth, :right_orth, :left_null, :right_null,
78+
:project_hermitian, :project_antihermitian, :project_isometric,
79+
]
80+
f! = Symbol(f, :!)
81+
@eval MAK.$f!(t::AbstractTensorMap, alg::DefaultAlgorithm) =
82+
MAK.$f!(t, MAK.select_algorithm(MAK.$f!, t, nothing; alg.kwargs...))
83+
@eval MAK.$f!(t::AbstractTensorMap, out, alg::DefaultAlgorithm) =
84+
MAK.$f!(t, out, MAK.select_algorithm(MAK.$f!, t, nothing; alg.kwargs...))
85+
86+
# disambiguate
87+
@eval MAK.$f!(t::AdjointTensorMap, alg::DefaultAlgorithm) =
88+
MAK.$f!(t, MAK.select_algorithm(MAK.$f!, t, nothing; alg.kwargs...))
89+
@eval MAK.$f!(t::AdjointTensorMap, out, alg::DefaultAlgorithm) =
90+
MAK.$f!(t, out, MAK.select_algorithm(MAK.$f!, t, nothing; alg.kwargs...))
91+
92+
@eval MAK.$f!(t::DiagonalTensorMap, alg::DefaultAlgorithm) =
93+
MAK.$f!(t, MAK.select_algorithm(MAK.$f!, t, nothing; alg.kwargs...))
94+
@eval MAK.$f!(t::DiagonalTensorMap, out, alg::DefaultAlgorithm) =
95+
MAK.$f!(t, out, MAK.select_algorithm(MAK.$f!, t, nothing; alg.kwargs...))
96+
end
97+
98+
6799
# Singular value decomposition
68100
# ----------------------------
69101
function MAK.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::AbstractAlgorithm)

test/cuda/factorizations.jl

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Adapt, CUDA, CUDA.cuRAND, cuTENSOR
22
using Test, TestExtras
33
using TensorKit
44
using LinearAlgebra: LinearAlgebra
5-
using MatrixAlgebraKit: diagview
5+
using MatrixAlgebraKit: DefaultAlgorithm, diagview
66
const CUDAExt = Base.get_extension(TensorKit, :TensorKitCUDAExt)
77
@assert !isnothing(CUDAExt) "Failed to load TensorKit - CUDA extension"
88
const CuTensorMap = getglobal(CUDAExt, :CuTensorMap)
@@ -35,21 +35,41 @@ for V in spacelist
3535
@test Q * R t
3636
@test isunitary(Q)
3737

38+
Q, R = @constinferred qr_full(t, DefaultAlgorithm())
39+
@test Q * R t
40+
@test isunitary(Q)
41+
3842
Q, R = @constinferred qr_compact(t)
3943
@test Q * R t
4044
@test isisometric(Q)
4145

46+
Q, R = @constinferred qr_compact(t, DefaultAlgorithm())
47+
@test Q * R t
48+
@test isisometric(Q)
49+
4250
Q, R = @constinferred left_orth(t)
4351
@test Q * R t
4452
@test isisometric(Q)
4553

54+
Q, R = @constinferred left_orth(t, DefaultAlgorithm())
55+
@test Q * R t
56+
@test isisometric(Q)
57+
4658
N = @constinferred qr_null(t)
4759
@test isisometric(N)
4860
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
4961

62+
N = @constinferred qr_null(t, DefaultAlgorithm())
63+
@test isisometric(N)
64+
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
65+
5066
N = @constinferred left_null(t)
5167
@test isisometric(N)
5268
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
69+
70+
N = @constinferred left_null(t, DefaultAlgorithm())
71+
@test isisometric(N)
72+
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
5373
end
5474

5575
# empty tensor
@@ -61,19 +81,38 @@ for V in spacelist
6181
@test isunitary(Q)
6282
@test dim(R) == dim(t) == 0
6383

84+
Q, R = @constinferred qr_full(t, DefaultAlgorithm())
85+
@test Q * R t
86+
@test isunitary(Q)
87+
@test dim(R) == dim(t) == 0
88+
6489
Q, R = @constinferred qr_compact(t)
6590
@test Q * R t
6691
@test isisometric(Q)
6792
@test dim(Q) == dim(R) == dim(t)
6893

94+
Q, R = @constinferred qr_compact(t, DefaultAlgorithm())
95+
@test Q * R t
96+
@test isisometric(Q)
97+
@test dim(Q) == dim(R) == dim(t)
98+
6999
Q, R = @constinferred left_orth(t)
70100
@test Q * R t
71101
@test isisometric(Q)
72102
@test dim(Q) == dim(R) == dim(t)
73103

104+
Q, R = @constinferred left_orth(t, DefaultAlgorithm())
105+
@test Q * R t
106+
@test isisometric(Q)
107+
@test dim(Q) == dim(R) == dim(t)
108+
74109
N = @constinferred qr_null(t)
75110
@test isunitary(N)
76111
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
112+
113+
N = @constinferred qr_null(t, DefaultAlgorithm())
114+
@test isunitary(N)
115+
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
77116
end
78117
end
79118

@@ -90,17 +129,33 @@ for V in spacelist
90129
@test L * Q t
91130
@test isunitary(Q)
92131

132+
L, Q = @constinferred lq_full(t, DefaultAlgorithm())
133+
@test L * Q t
134+
@test isunitary(Q)
135+
93136
L, Q = @constinferred lq_compact(t)
94137
@test L * Q t
95138
@test isisometric(Q; side = :right)
96139

140+
L, Q = @constinferred lq_compact(t, DefaultAlgorithm())
141+
@test L * Q t
142+
@test isisometric(Q; side = :right)
143+
97144
L, Q = @constinferred right_orth(t)
98145
@test L * Q t
99146
@test isisometric(Q; side = :right)
100147

148+
L, Q = @constinferred right_orth(t, DefaultAlgorithm())
149+
@test L * Q t
150+
@test isisometric(Q; side = :right)
151+
101152
Nᴴ = @constinferred lq_null(t)
102153
@test isisometric(Nᴴ; side = :right)
103154
@test norm(t * Nᴴ') 0 atol = 100 * eps(norm(t))
155+
156+
Nᴴ = @constinferred lq_null(t, DefaultAlgorithm())
157+
@test isisometric(Nᴴ; side = :right)
158+
@test norm(t * Nᴴ') 0 atol = 100 * eps(norm(t))
104159
end
105160

106161
for T in eltypes
@@ -112,19 +167,38 @@ for V in spacelist
112167
@test isunitary(Q)
113168
@test dim(L) == dim(t) == 0
114169

170+
L, Q = @constinferred lq_full(t, DefaultAlgorithm())
171+
@test L * Q t
172+
@test isunitary(Q)
173+
@test dim(L) == dim(t) == 0
174+
115175
L, Q = @constinferred lq_compact(t)
116176
@test L * Q t
117177
@test isisometric(Q; side = :right)
118178
@test dim(Q) == dim(L) == dim(t)
119179

180+
L, Q = @constinferred lq_compact(t, DefaultAlgorithm())
181+
@test L * Q t
182+
@test isisometric(Q; side = :right)
183+
@test dim(Q) == dim(L) == dim(t)
184+
120185
L, Q = @constinferred right_orth(t)
121186
@test L * Q t
122187
@test isisometric(Q; side = :right)
123188
@test dim(Q) == dim(L) == dim(t)
124189

190+
L, Q = @constinferred right_orth(t, DefaultAlgorithm())
191+
@test L * Q t
192+
@test isisometric(Q; side = :right)
193+
@test dim(Q) == dim(L) == dim(t)
194+
125195
Nᴴ = @constinferred lq_null(t)
126196
@test isunitary(Nᴴ)
127197
@test norm(t * Nᴴ') 0 atol = 100 * eps(norm(t))
198+
199+
Nᴴ = @constinferred lq_null(t, DefaultAlgorithm())
200+
@test isunitary(Nᴴ)
201+
@test norm(t * Nᴴ') 0 atol = 100 * eps(norm(t))
128202
end
129203
end
130204

@@ -143,6 +217,11 @@ for V in spacelist
143217
@test isisometric(w)
144218
@test isposdef(p)
145219

220+
w, p = @constinferred left_polar(t, DefaultAlgorithm())
221+
@test w * p t
222+
@test isisometric(w)
223+
@test isposdef(p)
224+
146225
w, p = @constinferred left_orth(t; alg = :polar)
147226
@test w * p t
148227
@test isisometric(w)
@@ -162,6 +241,11 @@ for V in spacelist
162241
@test isisometric(wᴴ; side = :right)
163242
@test isposdef(p)
164243

244+
p, wᴴ = @constinferred right_polar(t, DefaultAlgorithm())
245+
@test p * wᴴ t
246+
@test isisometric(wᴴ; side = :right)
247+
@test isposdef(p)
248+
165249
p, wᴴ = @constinferred right_orth(t; alg = :polar)
166250
@test p * wᴴ t
167251
@test isisometric(wᴴ; side = :right)
@@ -182,16 +266,31 @@ for V in spacelist
182266
@test isunitary(u)
183267
@test isunitary(vᴴ)
184268

269+
u, s, vᴴ = @constinferred svd_full(t, DefaultAlgorithm())
270+
@test u * s * vᴴ t
271+
@test isunitary(u)
272+
@test isunitary(vᴴ)
273+
185274
u, s, vᴴ = @constinferred svd_compact(t)
186275
@test u * s * vᴴ t
187276
@test isisometric(u)
188277
@test isposdef(s)
189278
@test isisometric(vᴴ; side = :right)
190279

280+
u, s, vᴴ = @constinferred svd_compact(t, DefaultAlgorithm())
281+
@test u * s * vᴴ t
282+
@test isisometric(u)
283+
@test isposdef(s)
284+
@test isisometric(vᴴ; side = :right)
285+
191286
s′ = @constinferred svd_vals(t)
192287
@test parent(s′) parent(diagview(s))
193288
@test s′ isa TensorKit.SectorVector
194289

290+
s′ = @constinferred svd_vals(t, DefaultAlgorithm())
291+
@test parent(s′) parent(diagview(s))
292+
@test s′ isa TensorKit.SectorVector
293+
195294
s2 = @constinferred DiagonalTensorMap(s′)
196295
@test s2 s
197296

@@ -230,9 +329,18 @@ for V in spacelist
230329
@test isunitary(U)
231330
@test isunitary(Vᴴ)
232331

332+
U, S, Vᴴ = @constinferred svd_full(t, DefaultAlgorithm())
333+
@test U * S * Vᴴ t
334+
@test isunitary(U)
335+
@test isunitary(Vᴴ)
336+
233337
U, S, Vᴴ = @constinferred svd_compact(t)
234338
@test U * S * Vᴴ t
235339
@test dim(U) == dim(S) == dim(Vᴴ) == dim(t) == 0
340+
341+
U, S, Vᴴ = @constinferred svd_compact(t, DefaultAlgorithm())
342+
@test U * S * Vᴴ t
343+
@test dim(U) == dim(S) == dim(Vᴴ) == dim(t) == 0
236344
end
237345
end
238346

@@ -253,6 +361,12 @@ for V in spacelist
253361
@test isisometric(U)
254362
@test isisometric(Vᴴ; side = :right)
255363

364+
U, S, Vᴴ, ϵ = @constinferred svd_trunc(t, DefaultAlgorithm(; trunc = notrunc()))
365+
@test U * S * Vᴴ t
366+
@test ϵ 0
367+
@test isisometric(U)
368+
@test isisometric(Vᴴ; side = :right)
369+
256370
# dimension of S is a float for IsingBimodule
257371
nvals = round(Int, dim(domain(S)) / 2)
258372
trunc = truncrank(nvals)
@@ -316,10 +430,17 @@ for V in spacelist
316430
d, v = @constinferred eig_full(t)
317431
@test t * v v * d
318432

433+
d, v = @constinferred eig_full(t, DefaultAlgorithm())
434+
@test t * v v * d
435+
319436
d′ = @constinferred eig_vals(t)
320437
@test parent(d′) parent(diagview(d))
321438
@test d′ isa TensorKit.SectorVector
322439

440+
d′ = @constinferred eig_vals(t, DefaultAlgorithm())
441+
@test parent(d′) parent(diagview(d))
442+
@test d′ isa TensorKit.SectorVector
443+
323444
d2 = @constinferred DiagonalTensorMap(d′)
324445
@test d2 d
325446

@@ -332,12 +453,21 @@ for V in spacelist
332453
@test t * v v * d
333454
#test_dim_isapprox(domain(d), nvals)
334455

456+
d, v = @constinferred eig_trunc(t, DefaultAlgorithm(; trunc = truncrank(nvals)))
457+
@test t * v v * d
458+
#test_dim_isapprox(domain(d), nvals)
459+
335460
t2 = @constinferred project_hermitian(t)
336461
D, V = eigen(t2)
337462
@test isisometric(V)
338463
D̃, Ṽ = @constinferred eigh_full(t2)
339464
@test D
340465
@test V
466+
467+
D̃, Ṽ = @constinferred eigh_full(t2, DefaultAlgorithm())
468+
@test D
469+
@test V
470+
341471
λ = minimum(real, parent(diagview(D)))
342472
@test cond(Ṽ) one(real(T))
343473
@test isposdef(t2) == isposdef(λ)
@@ -352,6 +482,10 @@ for V in spacelist
352482
@test parent(d′) parent(diagview(d))
353483
@test d′ isa TensorKit.SectorVector
354484

485+
d′ = @constinferred eigh_vals(t2, DefaultAlgorithm())
486+
@test parent(d′) parent(diagview(d))
487+
@test d′ isa TensorKit.SectorVector
488+
355489
λ = minimum(real, parent(diagview(d)))
356490
@test cond(v) one(real(T))
357491
@test isposdef(t2) == isposdef(λ)
@@ -361,6 +495,10 @@ for V in spacelist
361495
d, v = @constinferred eigh_trunc(t2; trunc = truncrank(nvals))
362496
@test t2 * v v * d
363497
#test_dim_isapprox(domain(d), nvals)
498+
499+
d, v = @constinferred eigh_trunc(t2, DefaultAlgorithm(; trunc = truncrank(nvals)))
500+
@test t2 * v v * d
501+
#test_dim_isapprox(domain(d), nvals)
364502
end
365503
end
366504

@@ -423,6 +561,11 @@ for V in spacelist
423561
th′ = @constinferred project_hermitian(t)
424562
@test ishermitian(th′)
425563
@test th′ th
564+
565+
th′ = @constinferred project_hermitian(t, DefaultAlgorithm())
566+
@test ishermitian(th′)
567+
@test th′ th
568+
426569
@test t == tc
427570
th_approx = th + noisefactor * ta
428571
@test !ishermitian(th_approx) || (T <: Real && t isa DiagonalTensorMap)
@@ -431,6 +574,11 @@ for V in spacelist
431574
ta′ = project_antihermitian(t)
432575
@test isantihermitian(ta′)
433576
@test ta′ ta
577+
578+
ta′ = @constinferred project_antihermitian(t, DefaultAlgorithm())
579+
@test isantihermitian(ta′)
580+
@test ta′ ta
581+
434582
@test t == tc
435583
ta_approx = ta + noisefactor * th
436584
@test !isantihermitian(ta_approx)
@@ -448,6 +596,10 @@ for V in spacelist
448596
)
449597
t2 = project_isometric(t)
450598
@test isisometric(t2)
599+
t2′ = @constinferred project_isometric(t, DefaultAlgorithm())
600+
@test isisometric(t2′)
601+
@test t2′ * ((t2′)' * t) t
602+
451603
t3 = project_isometric(t2)
452604
@test t3 t2 # stability of the projection
453605
@test t2 * (t2' * t) t

0 commit comments

Comments
 (0)