Skip to content

Commit ecb67a0

Browse files
committed
whole lotta change
1 parent 3d638c2 commit ecb67a0

19 files changed

Lines changed: 279 additions & 623 deletions

src/fusiontrees/braiding_manipulations.jl

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function artin_braid(f::FusionTree{I, N}, i; inv::Bool = false) where {I, N}
6464
a = inner_extended[i - 1]
6565
c = inner_extended[i]
6666
e = inner_extended[i + 1]
67-
c′ = first(a d)
67+
c′ = only(a d)
6868
coeff = oftype(
6969
oneT,
7070
if inv
@@ -122,6 +122,8 @@ function artin_braid(src::FusionTreeBlock{I, N, 0}, i; inv::Bool = false) where
122122
BraidingStyle(I) isa NoBraiding &&
123123
throw(SectorMismatch(lazy"Cannot braid sectors $a and $b"))
124124

125+
T = typeof(oneT)
126+
localbraidcache = Dict{NTuple{6, I}, FusionStyle(I) isa MultiplicityFreeFusion ? T : Array{T, 4}}()
125127
for (col, (f, f₂)) in enumerate(fusiontrees(src))
126128
inner = f.innerlines
127129
inner_extended = (uncoupled[1], inner..., coupled′)
@@ -156,10 +158,8 @@ function artin_braid(src::FusionTreeBlock{I, N, 0}, i; inv::Bool = false) where
156158
e = inner_extended[i + 1]
157159
if FusionStyle(I) isa MultiplicityFreeFusion
158160
for c′ in intersect(a d, e conj(b))
159-
coeff = if inv
160-
conj(Rsymbol(d, c, e) * Fsymbol(d, a, b, e, c′, c)) * Rsymbol(d, a, c′)
161-
else
162-
Rsymbol(c, d, e) * conj(Fsymbol(d, a, b, e, c′, c) * Rsymbol(a, d, c′))
161+
coeff = let k = (a, b, c, d, e, c′)
162+
get!(() -> _artin_braid_local(k, inv), localbraidcache, k)
163163
end
164164
iszero(coeff) && continue
165165
inner′ = TupleTools.setindex(inner, c′, i - 1)
@@ -172,15 +172,14 @@ function artin_braid(src::FusionTreeBlock{I, N, 0}, i; inv::Bool = false) where
172172
Rmat1 = inv ? Rsymbol(d, c, e)' : Rsymbol(c, d, e)
173173
Rmat2 = inv ? Rsymbol(d, a, c′)' : Rsymbol(a, d, c′)
174174
Fmat = Fsymbol(d, a, b, e, c′, c)
175+
coeff_tensor = let k = (a, b, c, d, e, c′)
176+
get!(() -> _artin_braid_local(k, inv), localbraidcache, k)
177+
end
175178
μ = vertices[i - 1]
176179
ν = vertices[i]
177-
for σ in 1:Nsymbol(a, d, c′)
178-
for λ in 1:Nsymbol(c′, b, e)
179-
coeff = zero(oneT)
180-
for ρ in 1:Nsymbol(d, c, e), κ in 1:Nsymbol(d, a, c′)
181-
coeff += Rmat1[ν, ρ] * conj(Fmat[κ, λ, μ, ρ]) *
182-
conj(Rmat2[σ, κ])
183-
end
180+
for λ in 1:size(coeff_tensor, 2)
181+
for σ in 1:size(coeff_tensor, 1)
182+
coeff = coeff_tensor[σ, λ, μ, ν]
184183
iszero(coeff) && continue
185184
vertices′ = TupleTools.setindex(vertices, σ, i - 1)
186185
vertices′ = TupleTools.setindex(vertices′, λ, i)
@@ -193,10 +192,26 @@ function artin_braid(src::FusionTreeBlock{I, N, 0}, i; inv::Bool = false) where
193192
end
194193
end
195194
end
196-
197195
return dst => U
198196
end
199197

198+
function _artin_braid_local((a, b, c, d, e, c′)::NTuple{6, I}, inv::Bool) where {I}
199+
if FusionStyle(I) isa MultiplicityFreeFusion
200+
coeff = if inv
201+
conj(Rsymbol(d, c, e) * Fsymbol(d, a, b, e, c′, c)) * Rsymbol(d, a, c′)
202+
else
203+
Rsymbol(c, d, e) * conj(Fsymbol(d, a, b, e, c′, c) * Rsymbol(a, d, c′))
204+
end
205+
return coeff
206+
else
207+
Rmat1 = inv ? Rsymbol(d, c, e)' : Rsymbol(c, d, e)
208+
Rmat2 = inv ? Rsymbol(d, a, c′)' : Rsymbol(a, d, c′)
209+
Fmat = Fsymbol(d, a, b, e, c′, c)
210+
@tensor coeff[σ, λ, μ, ν] := Rmat1[ν, ρ] * conj(Fmat[κ, λ, μ, ρ]) * conj(Rmat2[σ, κ])
211+
return coeff
212+
end
213+
end
214+
200215
# braid fusion tree
201216
"""
202217
braid(f::FusionTree{<:Sector, N}, p::NTuple{N, Int}, levels::NTuple{N, Int})
@@ -223,7 +238,7 @@ function braid(f::FusionTree{I, N}, (p, _)::Index2Tuple{N, 0}, (levels, _)::Inde
223238
for j in 1:(i - 1)
224239
if p[j] > p[i]
225240
a, b = f.uncoupled[p[j]], f.uncoupled[p[i]]
226-
coeff *= Rsymbol(a, b, first(a b))
241+
coeff *= Rsymbol(a, b, only(a b))
227242
end
228243
end
229244
end
@@ -308,19 +323,22 @@ end
308323
f′, coeff2 = braid(f, p, levels)
309324
(f₁′, f₂′), coeff3 = repartition((f′, f0), N₁)
310325
return (f₁′, f₂′) => coeff1 * coeff2 * coeff3
311-
312326
else
313327
src, (p1, p2), (l1, l2) = key
314328

315329
p = linearizepermutation(p1, p2, numout(src), numin(src))
316330
levels = (l1..., reverse(l2)...)
317331

318-
dst, U = repartition(src, numind(src))
332+
dst, U′ = repartition(src, numind(src))
333+
T = sectorscalartype(I)
334+
U = eltype(U′) == T ? U′ : T.(U′) # U′ has fusionscalartype(I) elements
335+
Uold = similar(U)
319336

320337
for s in permutation2swaps(p)
338+
U, Uold = Uold, U
321339
inv = levels[s] > levels[s + 1]
322340
dst, U_tmp = artin_braid(dst, s; inv)
323-
U = U_tmp * U
341+
U = mul!(U, U_tmp, Uold)
324342
l = levels[s]
325343
levels = TupleTools.setindex(levels, levels[s + 1], s)
326344
levels = TupleTools.setindex(levels, l, s + 1)
@@ -329,8 +347,9 @@ end
329347
if N₂ == 0
330348
return dst => U
331349
else
350+
U, Uold = Uold, U
332351
dst, U_tmp = repartition(dst, N₁)
333-
U = U_tmp * U
352+
U = mul!(U, U_tmp, Uold)
334353
return dst => U
335354
end
336355
end

src/spaces/productspace.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,11 +278,7 @@ function insertleftunit(
278278
u = unitspace(spacetype(P))
279279
else
280280
N > 0 || throw(ArgumentError("cannot insert a sensible unit space in the empty product space"))
281-
if i == N + 1
282-
u = rightunitspace(P[N])
283-
else
284-
u = leftunitspace(P[i])
285-
end
281+
u = (i == N + 1) ? rightunitspace(P[N]) : leftunitspace(P[i])
286282
end
287283
if dual
288284
u = TensorKit.dual(u)
@@ -312,7 +308,7 @@ function insertrightunit(
312308
u = unitspace(spacetype(P))
313309
else
314310
N > 0 || throw(ArgumentError("cannot insert a sensible unit space in the empty product space"))
315-
u = rightunitspace(P[i])
311+
u = (i == 0 ) ? leftunitspace(P[1]) : rightunitspace(P[i])
316312
end
317313
if dual
318314
u = TensorKit.dual(u)

test/factorizations/eig.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ for V in spacelist
1717
@timedtestset "Factorizations with symmetry: $Istr" verbose = true begin
1818
V1, V2, V3, V4, V5 = V
1919
W = V1 V2
20-
@assert !isempty(blocksectors(W))
21-
@assert !isempty(intersect(blocksectors(V4), blocksectors(W)))
2220

2321
@testset "Eigenvalue decomposition" begin
2422
for T in eltypes,

test/factorizations/ortho.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,11 @@ for V in spacelist
1717
@timedtestset "Factorizations with symmetry: $Istr" verbose = true begin
1818
V1, V2, V3, V4, V5 = V
1919
W = V1 V2
20-
@assert !isempty(blocksectors(W))
21-
@assert !isempty(intersect(blocksectors(V4), blocksectors(W)))
2220

2321
@testset "QR decomposition" begin
2422
for T in eltypes,
2523
t in (
26-
rand(T, W, W), rand(T, W, W)', rand(T, W, V4), rand(T, V4, W)',
24+
rand(T, W, W), rand(T, W, W)', rand(T, (V1 V2 V3), (V4 V5)'), rand(T, (V1 V2)', (V3 V4 V5))',
2725
DiagonalTensorMap(rand(T, reduceddim(V1)), V1),
2826
)
2927

@@ -76,7 +74,7 @@ for V in spacelist
7674
@testset "LQ decomposition" begin
7775
for T in eltypes,
7876
t in (
79-
rand(T, W, W), rand(T, W, W)', rand(T, W, V4), rand(T, V4, W)',
77+
rand(T, W, W), rand(T, W, W)', rand(T, (V1 V2), (V3 V4 V5)'), rand(T, (V1 V2 V3)', (V4 V5))',
8078
DiagonalTensorMap(rand(T, reduceddim(V1)), V1),
8179
)
8280

test/factorizations/projections.jl

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,46 +17,6 @@ for V in spacelist
1717
@timedtestset "Factorizations with symmetry: $Istr" verbose = true begin
1818
V1, V2, V3, V4, V5 = V
1919
W = V1 V2
20-
@assert !isempty(blocksectors(W))
21-
@assert !isempty(intersect(blocksectors(V4), blocksectors(W)))
22-
23-
@testset "Condition number and rank" begin
24-
for T in eltypes,
25-
t in (
26-
rand(T, W, W), rand(T, W, W)',
27-
rand(T, W, V4), rand(T, V4, W),
28-
rand(T, W, V4)', rand(T, V4, W)',
29-
DiagonalTensorMap(rand(T, reduceddim(V1)), V1),
30-
)
31-
32-
d1, d2 = dim(codomain(t)), dim(domain(t))
33-
r = rank(t)
34-
@test r == min(d1, d2)
35-
@test typeof(r) == typeof(d1)
36-
M = left_null(t)
37-
@test @constinferred(rank(M)) + r d1
38-
Mᴴ = right_null(t)
39-
@test rank(Mᴴ) + r d2
40-
end
41-
for T in eltypes
42-
u = unitary(T, V1 V2, V1 V2)
43-
@test @constinferred(cond(u)) one(real(T))
44-
@test @constinferred(rank(u)) == dim(V1 V2)
45-
46-
t = rand(T, zerospace(V1), W)
47-
@test rank(t) == 0
48-
t2 = rand(T, zerospace(V1) * zerospace(V2), zerospace(V1) * zerospace(V2))
49-
@test rank(t2) == 0
50-
@test cond(t2) == 0.0
51-
end
52-
for T in eltypes, t in (rand(T, W, W), rand(T, W, W)')
53-
project_hermitian!(t)
54-
vals = @constinferred LinearAlgebra.eigvals(t)
55-
λmax = maximum(s -> maximum(abs, s), values(vals))
56-
λmin = minimum(s -> minimum(abs, s), values(vals))
57-
@test cond(t) λmax / λmin
58-
end
59-
end
6020

6121
@testset "Hermitian projections" begin
6222
for T in eltypes,
@@ -92,8 +52,7 @@ for V in spacelist
9252
@testset "Isometric projections" begin
9353
for T in eltypes,
9454
t in (
95-
randn(T, W, W), randn(T, W, W)',
96-
randn(T, W, V4), randn(T, V4, W)',
55+
rand(T, W, W), rand(T, W, W)', rand(T, (V1 V2 V3), (V4 V5)'), rand(T, (V1 V2)', (V3 V4 V5))',
9756
)
9857
t2 = project_isometric(t)
9958
@test isisometric(t2)

test/factorizations/svd.jl

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,49 @@ for V in spacelist
1717
@timedtestset "Factorizations with symmetry: $Istr" verbose = true begin
1818
V1, V2, V3, V4, V5 = V
1919
W = V1 V2
20-
@assert !isempty(blocksectors(W))
21-
@assert !isempty(intersect(blocksectors(V4), blocksectors(W)))
20+
21+
@testset "Condition number and rank" begin
22+
for T in eltypes,
23+
t in (
24+
rand(T, W, W), rand(T, W, W)',
25+
rand(T, (V1 V2 V3), (V4 V5)'), rand(T, (V1 V2)', (V3 V4 V5))',
26+
rand(T, (V1 V2), (V3 V4 V5)'), rand(T, (V1 V2 V3)', (V4 V5))',
27+
DiagonalTensorMap(rand(T, reduceddim(V1)), V1),
28+
)
29+
30+
d1, d2 = dim(codomain(t)), dim(domain(t))
31+
r = rank(t)
32+
@test r == min(d1, d2)
33+
@test typeof(r) == typeof(d1)
34+
M = left_null(t)
35+
@test @constinferred(rank(M)) + r d1
36+
Mᴴ = right_null(t)
37+
@test rank(Mᴴ) + r d2
38+
end
39+
for T in eltypes
40+
u = unitary(T, V1 V2, V1 V2)
41+
@test @constinferred(cond(u)) one(real(T))
42+
@test @constinferred(rank(u)) == dim(V1 V2)
43+
44+
t = rand(T, zerospace(V1), W)
45+
@test rank(t) == 0
46+
t2 = rand(T, zerospace(V1) * zerospace(V2), zerospace(V1) * zerospace(V2))
47+
@test rank(t2) == 0
48+
@test cond(t2) == 0.0
49+
end
50+
for T in eltypes, t in (rand(T, W, W), rand(T, W, W)')
51+
project_hermitian!(t)
52+
vals = @constinferred LinearAlgebra.eigvals(t)
53+
λmax = maximum(s -> maximum(abs, s), values(vals))
54+
λmin = minimum(s -> minimum(abs, s), values(vals))
55+
@test cond(t) λmax / λmin
56+
end
57+
end
2258

2359
@testset "Polar decomposition" begin
2460
for T in eltypes,
2561
t in (
26-
rand(T, W, W), rand(T, W, W)', rand(T, W, V4), rand(T, V4, W)',
62+
rand(T, W, W), rand(T, (V1 V2 V3), (V4 V5)'), rand(T, (V1 V2)', (V3 V4 V5))',
2763
DiagonalTensorMap(rand(T, reduceddim(V1)), V1),
2864
)
2965

@@ -39,7 +75,10 @@ for V in spacelist
3975
end
4076

4177
for T in eltypes,
42-
t in (rand(T, W, W), rand(T, W, W)', rand(T, V4, W), rand(T, W, V4)')
78+
t in (
79+
rand(T, W, W), rand(T, W, W)', rand(T, (V1 V2), (V3 V4 V5)'), rand(T, (V1 V2 V3)', (V4 V5))',
80+
DiagonalTensorMap(rand(T, reduceddim(V1)), V1),
81+
)
4382

4483
@assert codomain(t) domain(t)
4584
p, wᴴ = @constinferred right_polar(t)
@@ -57,8 +96,8 @@ for V in spacelist
5796
for T in eltypes,
5897
t in (
5998
rand(T, W, W), rand(T, W, W)',
60-
rand(T, W, V4), rand(T, V4, W),
61-
rand(T, W, V4)', rand(T, V4, W)',
99+
rand(T, (V1 V2 V3), (V4 V5)'), rand(T, (V1 V2)', (V3 V4 V5))',
100+
rand(T, (V1 V2), (V3 V4 V5)'), rand(T, (V1 V2 V3)', (V4 V5))',
62101
DiagonalTensorMap(rand(T, reduceddim(V1)), V1),
63102
)
64103

@@ -121,10 +160,10 @@ for V in spacelist
121160
@testset "truncated SVD" begin
122161
for T in eltypes,
123162
t in (
124-
randn(T, W, W), randn(T, W, W)',
125-
randn(T, W, V4), randn(T, V4, W),
126-
randn(T, W, V4)', randn(T, V4, W)',
127-
DiagonalTensorMap(randn(T, reduceddim(V1)), V1),
163+
rand(T, W, W), rand(T, W, W)',
164+
rand(T, (V1 V2 V3), (V4 V5)'), rand(T, (V1 V2)', (V3 V4 V5))',
165+
rand(T, (V1 V2), (V3 V4 V5)'), rand(T, (V1 V2 V3)', (V4 V5))',
166+
DiagonalTensorMap(rand(T, reduceddim(V1)), V1),
128167
)
129168

130169
@constinferred normalize!(t)

test/mooncake/factorizations.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ rng = Random.default_rng()
1313
spacelist = ad_spacelist(fast_tests)
1414
eltypes = (Float64, ComplexF64)
1515

16-
1716
@timedtestset "Mooncake - Factorizations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
1817
atol = default_tol(T)
1918
rtol = default_tol(T)
@@ -31,7 +30,7 @@ eltypes = (Float64, ComplexF64)
3130
# TODO:
3231
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
3332

34-
A = randn(T, V[1] V[2] V[1])
33+
A = randn(T, V[1] V[2] V[3] (V[4] V[5])')
3534

3635
Mooncake.TestUtils.test_rule(rng, qr_compact, A; atol, rtol, mode, is_primitive = false)
3736

@@ -57,7 +56,7 @@ eltypes = (Float64, ComplexF64)
5756
# TODO:
5857
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
5958

60-
A = randn(T, V[1] V[2] V[1])
59+
A = randn(T, V[1] V[2] (V[3] V[4] V[5])')
6160

6261
Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false)
6362

@@ -86,7 +85,7 @@ eltypes = (Float64, ComplexF64)
8685
end
8786

8887
@timedtestset "Singular value decomposition" begin
89-
for t in (randn(T, V[1] V[1]), randn(T, V[1] V[2] V[3] V[4]))
88+
for t in (randn(T, V[1] V[1]), randn(T, V[1] V[2] (V[3] V[4] V[5])'))
9089
USVᴴ = svd_compact(t)
9190
ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ)
9291
remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)

0 commit comments

Comments
 (0)