Skip to content

Commit d3d6a24

Browse files
committed
orthnull progress
1 parent 94f9048 commit d3d6a24

5 files changed

Lines changed: 84 additions & 45 deletions

File tree

src/TensorKit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ export left_orth, right_orth, left_null, right_null,
8282
eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eig_full!, eig_full, eig_trunc!,
8383
eig_trunc,
8484
eigh_vals!, eigh_vals, eig_vals!, eig_vals,
85-
isposdef, isposdef!, ishermitian, isisometric, isunitary, sylvester, rank, cond
85+
isposdef, isposdef!, ishermitian, isisometric, isunitary, sylvester, rank, cond,
86+
LeftOrthAlgorithm, RightOrthAlgorithm
8687

8788
export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition,
8889
repartition!

src/factorizations/factorizations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using TensorOperations: Index2Tuple
1616
using MatrixAlgebraKit
1717
import MatrixAlgebraKit as MAK
1818
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm
19+
using MatrixAlgebraKit: LeftOrthAlgorithm, RightOrthAlgorithm
1920
using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue,
2021
TruncationByError, TruncationIntersection, TruncationByFilter, TruncationByOrder
2122
using MatrixAlgebraKit: diagview, isisometric

src/factorizations/matrixalgebrakit.jl

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
2626
for f! in (
2727
:qr_compact!, :qr_full!, :lq_compact!, :lq_full!,
2828
:eig_full!, :eigh_full!, :svd_compact!, :svd_full!,
29-
:left_polar!, :right_polar!, :left_orth!, :right_orth!,
29+
:left_polar!, :right_polar!
3030
)
3131
@eval function MAK.$f!(t::AbstractTensorMap, F, alg::AbstractAlgorithm)
3232
MAK.check_input($f!, t, F, alg)
@@ -45,6 +45,43 @@ for f! in (
4545
end
4646
end
4747

48+
for alg in (MAK.LeftOrthAlgorithm{:qr}, MAK.LeftOrthAlgorithm{:svd}, MAK.LeftOrthAlgorithm{:polar})
49+
@eval begin
50+
function MAK.left_orth!(t::AbstractTensorMap, F, alg::$alg)
51+
MAK.check_input(left_orth!, t, F, alg)
52+
53+
foreachblock(t, F...) do _, bs
54+
factors = Base.tail(bs)
55+
factors′ = left_orth!(first(bs), factors, alg)
56+
# deal with the case where the output is not in-place
57+
for (f′, f) in zip(factors′, factors)
58+
f′ === f || copy!(f, f′)
59+
end
60+
return nothing
61+
end
62+
return F
63+
end
64+
end
65+
end
66+
67+
for alg in (MAK.RightOrthAlgorithm{:lq}, MAK.RightOrthAlgorithm{:svd}, MAK.RightOrthAlgorithm{:polar})
68+
@eval function MAK.right_orth!(t::AbstractTensorMap, F, alg::$alg)
69+
MAK.check_input(right_orth!, t, F, alg)
70+
71+
foreachblock(t, F...) do _, bs
72+
factors = Base.tail(bs)
73+
factors′ = right_orth!(first(bs), factors, alg)
74+
# deal with the case where the output is not in-place
75+
for (f′, f) in zip(factors′, factors)
76+
f′ === f || copy!(f, f′)
77+
end
78+
return nothing
79+
end
80+
return F
81+
end
82+
end
83+
84+
4885
# Handle these separately because single output instead of tuple
4986
for f! in (:qr_null!, :lq_null!)
5087
@eval function MAK.$f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm)
@@ -464,6 +501,18 @@ function MAK.check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, ::Abstr
464501
return nothing
465502
end
466503

504+
function MAK.check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, alg::MAK.LeftOrthAlgorithm{:qr})
505+
return MAK.check_input(qr_compact!, t, VC, alg.alg)
506+
end
507+
508+
function MAK.check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, alg::MAK.LeftOrthAlgorithm{:svd})
509+
return MAK.check_input(svd_compact!, t, VC, alg.alg)
510+
end
511+
512+
function MAK.check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, alg::MAK.LeftOrthAlgorithm{:polar})
513+
return MAK.check_input(left_polar!, t, VC, alg.alg)
514+
end
515+
467516
function MAK.check_input(::typeof(right_orth!), t::AbstractTensorMap, CVᴴ, ::AbstractAlgorithm)
468517
C, Vᴴ = CVᴴ
469518

@@ -479,6 +528,18 @@ function MAK.check_input(::typeof(right_orth!), t::AbstractTensorMap, CVᴴ, ::A
479528
return nothing
480529
end
481530

531+
function MAK.check_input(::typeof(right_orth!), t::AbstractTensorMap, VC, alg::MAK.RightOrthAlgorithm{:lq})
532+
return MAK.check_input(lq_compact!, t, VC, alg.alg)
533+
end
534+
535+
function MAK.check_input(::typeof(right_orth!), t::AbstractTensorMap, VC, alg::MAK.RightOrthAlgorithm{:svd})
536+
return MAK.check_input(svd_compact!, t, VC, alg.alg)
537+
end
538+
539+
function MAK.check_input(::typeof(right_orth!), t::AbstractTensorMap, VC, alg::MAK.RightOrthAlgorithm{:polar})
540+
return MAK.check_input(right_polar!, t, VC, alg.alg)
541+
end
542+
482543
function MAK.initialize_output(::typeof(left_orth!), t::AbstractTensorMap)
483544
V_C = infimum(fuse(codomain(t)), fuse(domain(t)))
484545
V = similar(t, codomain(t) V_C)
@@ -501,44 +562,16 @@ end
501562
function MAK.left_orth!(
502563
t::AbstractTensorMap;
503564
trunc::TruncationStrategy = notrunc(),
504-
kind = trunc == notrunc() ? :qr : :svd,
505-
alg_qr = (; positive = true), alg_polar = (;), alg_svd = (;)
565+
alg::AbstractAlgorithm = (trunc == notrunc()) ? MAK.select_algorithm(left_orth!, t, Val(:qr)) : MAK.select_algorithm(left_orth!, t, Val(:svd); trunc)
506566
)
507-
trunc == notrunc() || kind === :svd ||
508-
throw(ArgumentError("truncation not supported for left_orth with kind = $kind"))
509-
510-
return if kind === :qr
511-
alg_qr isa NamedTuple ? qr_compact!(t; alg_qr...) : qr_compact!(t; alg = alg_qr)
512-
elseif kind === :polar
513-
alg_polar isa NamedTuple ? left_orth_polar!(t; alg_polar...) :
514-
left_orth_polar!(t; alg = alg_polar)
515-
elseif kind === :svd
516-
alg_svd isa NamedTuple ? left_orth_svd!(t; trunc, alg_svd...) :
517-
left_orth_svd!(t; trunc, alg = alg_svd)
518-
else
519-
throw(ArgumentError(lazy"`left_orth!` received unknown value `kind = $kind`"))
520-
end
567+
MAK.left_orth!(t, MAK.initialize_output(left_orth!, t, alg), alg)
521568
end
522569
function MAK.right_orth!(
523570
t::AbstractTensorMap;
524571
trunc::TruncationStrategy = notrunc(),
525-
kind = trunc == notrunc() ? :lq : :svd,
526-
alg_lq = (; positive = true), alg_polar = (;), alg_svd = (;)
572+
alg::AbstractAlgorithm = (trunc == notrunc()) ? MAK.select_algorithm(right_orth!, t, Val(:lq)) : MAK.select_algorithm(right_orth!, t, Val(:svd); trunc)
527573
)
528-
trunc == notrunc() || kind === :svd ||
529-
throw(ArgumentError("truncation not supported for right_orth with kind = $kind"))
530-
531-
return if kind === :lq
532-
alg_lq isa NamedTuple ? lq_compact!(t; alg_lq...) : lq_compact!(t; alg = alg_lq)
533-
elseif kind === :polar
534-
alg_polar isa NamedTuple ? right_orth_polar!(t; alg_polar...) :
535-
right_orth_polar!(t; alg = alg_polar)
536-
elseif kind === :svd
537-
alg_svd isa NamedTuple ? right_orth_svd!(t; trunc, alg_svd...) :
538-
right_orth_svd!(t; trunc, alg = alg_svd)
539-
else
540-
throw(ArgumentError(lazy"`right_orth!` received unknown value `kind = $kind`"))
541-
end
574+
MAK.right_orth!(t, MAK.initialize_output(right_orth!, t, alg), alg)
542575
end
543576

544577
# Nullspace

test/autodiff/ad.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ function remove_lqgauge_dependence!(ΔQ, t, Q)
103103
return ΔQ
104104
end
105105
function remove_eiggauge_dependence!(
106-
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
106+
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
107107
)
108108
gaugepart = V' * ΔV
109109
for (c, b) in blocks(gaugepart)
@@ -119,7 +119,7 @@ function remove_eiggauge_dependence!(
119119
return ΔV
120120
end
121121
function remove_eighgauge_dependence!(
122-
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
122+
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
123123
)
124124
gaugepart = V' * ΔV
125125
gaugepart = (gaugepart - gaugepart') / 2
@@ -136,7 +136,7 @@ function remove_eighgauge_dependence!(
136136
return ΔV
137137
end
138138
function remove_svdgauge_dependence!(
139-
ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
139+
ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S)
140140
)
141141
gaugepart = U' * ΔU + Vᴴ * ΔVᴴ'
142142
gaugepart = (gaugepart - gaugepart') / 2

test/tensors/factorizations.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@ for V in spacelist
4949
@test Q * R t
5050
@test isisometric(Q)
5151

52-
Q, R = @constinferred left_orth(t; kind = :qr)
52+
Q, R = @constinferred left_orth(t)
5353
@test Q * R t
5454
@test isisometric(Q)
5555

5656
N = @constinferred qr_null(t)
5757
@test isisometric(N)
5858
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
5959

60-
N = @constinferred left_null(t; kind = :qr)
60+
N = @constinferred left_null(t)
6161
@test isisometric(N)
6262
@test norm(N' * t) 0 atol = 100 * eps(norm(t))
6363
end
@@ -76,7 +76,7 @@ for V in spacelist
7676
@test isisometric(Q)
7777
@test dim(Q) == dim(R) == dim(t)
7878

79-
Q, R = @constinferred left_orth(t; kind = :qr)
79+
Q, R = @constinferred left_orth(t)
8080
@test Q * R t
8181
@test isisometric(Q)
8282
@test dim(Q) == dim(R) == dim(t)
@@ -102,7 +102,7 @@ for V in spacelist
102102
@test L * Q t
103103
@test isisometric(Q; side = :right)
104104

105-
L, Q = @constinferred right_orth(t; kind = :lq)
105+
L, Q = @constinferred right_orth(t)
106106
@test L * Q t
107107
@test isisometric(Q; side = :right)
108108

@@ -125,7 +125,7 @@ for V in spacelist
125125
@test isisometric(Q; side = :right)
126126
@test dim(Q) == dim(L) == dim(t)
127127

128-
L, Q = @constinferred right_orth(t; kind = :lq)
128+
L, Q = @constinferred right_orth(t)
129129
@test L * Q t
130130
@test isisometric(Q; side = :right)
131131
@test dim(Q) == dim(L) == dim(t)
@@ -149,7 +149,7 @@ for V in spacelist
149149
@test isisometric(w)
150150
@test isposdef(p)
151151

152-
w, p = @constinferred left_orth(t; kind = :polar)
152+
w, p = @constinferred left_orth(t; alg = TensorKit.LeftOrthAlgorithm{:polar})
153153
@test w * p t
154154
@test isisometric(w)
155155
end
@@ -163,7 +163,7 @@ for V in spacelist
163163
@test isisometric(wᴴ; side = :right)
164164
@test isposdef(p)
165165

166-
p, wᴴ = @constinferred right_orth(t; kind = :polar)
166+
p, wᴴ = @constinferred right_orth(t; alg = TensorKit.RightOrthAlgorithm{:polar})
167167
@test p * wᴴ t
168168
@test isisometric(wᴴ; side = :right)
169169
end
@@ -194,9 +194,13 @@ for V in spacelist
194194
@test b s′[c]
195195
end
196196

197-
v, c = @constinferred left_orth(t; kind = :svd)
197+
v, c = @constinferred left_orth(t; alg = TensorKit.LeftOrthAlgorithm{:svd})
198198
@test v * c t
199199
@test isisometric(v)
200+
201+
c, vᴴ = @constinferred right_orth(t; alg = TensorKit.RightOrthAlgorithm{:svd})
202+
@test c * vᴴ t
203+
@test isisometric(v; side = :right)
200204

201205
N = @constinferred left_null(t; kind = :svd)
202206
@test isisometric(N)

0 commit comments

Comments
 (0)