Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ext/TensorKitChainRulesCoreExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bo
return tA, twist_pullback
end

function ChainRulesCore.rrule(::typeof(flip), A::AbstractTensorMap, is; inv::Bool=false)
tA = flip(A, is; inv)
flip_pullback(ΔA) = NoTangent(), flip(unthunk(ΔA), is; inv=!inv), NoTangent()
return tA, flip_pullback
end

function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd)
return dot(a, b), dot_pullback
Expand Down
30 changes: 26 additions & 4 deletions src/fusiontrees/manipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,24 +243,46 @@ end
# -> A-move (foldleft, foldright) is complicated, needs to be reexpressed in standard form

# flip a duality flag of a fusion tree
function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, i::Int) where {I<:Sector,N₁,N₂}
function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, i::Int;
inv::Bool=false) where {I<:Sector,N₁,N₂}
@assert 0 < i ≤ N₁ + N₂
if i ≤ N₁
a = f₁.uncoupled[i]
fs = frobeniusschur(a) * twist(a)
factor = f₁.isdual[i] ? fs : one(fs)
χₐ = frobeniusschur(a)
θₐ = twist(a)
if !inv
factor = f₁.isdual[i] ? χₐ * θₐ : one(θₐ)
else
factor = f₁.isdual[i] ? one(θₐ) : χₐ * conj(θₐ)
end
isdual′ = TupleTools.setindex(f₁.isdual, !f₁.isdual[i], i)
f₁′ = FusionTree{I}(f₁.uncoupled, f₁.coupled, isdual′, f₁.innerlines, f₁.vertices)
return SingletonDict((f₁′, f₂) => factor)
else
i -= N₁
a = f₂.uncoupled[i]
factor = f₂.isdual[i] ? frobeniusschur(a) : twist(a)
χₐ = frobeniusschur(a)
θₐ = twist(a)
if !inv
factor = f₂.isdual[i] ? χₐ * one(θₐ) : θₐ
else
factor = f₂.isdual[i] ? conj(θₐ) : χₐ * one(θₐ)
end
isdual′ = TupleTools.setindex(f₂.isdual, !f₂.isdual[i], i)
f₂′ = FusionTree{I}(f₂.uncoupled, f₂.coupled, isdual′, f₂.innerlines, f₂.vertices)
return SingletonDict((f₁, f₂′) => factor)
end
end
function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, ind;
inv::Bool=false) where {I<:Sector,N₁,N₂}
f₁′, f₂′ = f₁, f₂
factor = one(sectorscalartype(I))
for i in ind
(f₁′, f₂′), s = only(flip(f₁′, f₂′, i; inv))
factor *= s
end
return SingletonDict((f₁′, f₂′) => factor)
end

# change to N₁ - 1, N₂ + 1
function bendright(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}) where {I<:Sector,N₁,N₂}
Expand Down
16 changes: 9 additions & 7 deletions src/tensors/indexmanipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@

Return a new tensor that is isomorphic to `t` but where the arrows on the indices `i` that satisfy
`i ∈ I` are flipped, i.e. `space(t′, i) = flip(space(t, i))`.

!!! note
The isomorphism that `flip` applies to each of the indices `i ∈ I` is such that flipping two indices
that are afterwards contracted within an `@tensor` contraction will yield the same result as without
flipping those indices first. However, `flip` is not involutory, i.e. `flip(flip(t, I), I) != t` in
general. To obtain the original tensor, one can use the `inv` keyword, i.e. it holds that
`flip(flip(t, I), I; inv=true) == t`.
"""
function flip(t::AbstractTensorMap, I)
function flip(t::AbstractTensorMap, I; inv::Bool=false)
P = flip(space(t), I)
t′ = similar(t, P)
for (f₁, f₂) in fusiontrees(t)
f₁′, f₂′ = f₁, f₂
factor = one(scalartype(t))
for i in I
(f₁′, f₂′), s = only(flip(f₁′, f₂′, i))
factor *= s
end
(f₁′, f₂′), factor = only(flip(f₁, f₂, I; inv))
scale!(t′[f₁′, f₂′], t[f₁, f₂], factor)
end
return t′
Expand Down
3 changes: 3 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(twist, A, 1)
test_rrule(twist, A, [1, 3])

test_rrule(flip, A, 1)
test_rrule(flip, A, [1, 3, 4])

D = randn(T, V[1] ⊗ V[2] ← V[3])
E = randn(T, V[4] ← V[5])
symmetricbraiding && test_rrule(⊗, D, E)
Expand Down
7 changes: 7 additions & 0 deletions test/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,13 @@ for V in spacelist
@test HrA12array ≈ convert(Array, HrA12)
end
end
@timedtestset "Index flipping: test flipping inverse" begin
t = rand(ComplexF64, V1 ⊗ V1' ← V1' ⊗ V1)
for i in 1:4
@test t ≈ flip(flip(t, i), i; inv=true)
@test t ≈ flip(flip(t, i; inv=true), i)
end
end
@timedtestset "Index flipping: test via explicit flip" begin
t = rand(ComplexF64, V1 ⊗ V1' ← V1' ⊗ V1)
F1 = unitary(flip(V1), V1)
Expand Down