Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ext/TensorKitChainRulesCoreExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@
return tA, twist_pullback
end

function ChainRulesCore.rrule(::typeof(flip), A::AbstractTensorMap, is; inv::Bool=false)
tA = flip(A, is; inv)
flip_pullback(ΔA) = NoTangent(), flip(unthunk(ΔA), is; inv=!inv), NoTangent()
return tA, flip_pullback

Check warning on line 89 in ext/TensorKitChainRulesCoreExt/linalg.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/linalg.jl#L86-L89

Added lines #L86 - L89 were not covered by tests
end

function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd)
return dot(a, b), dot_pullback
Expand Down
30 changes: 26 additions & 4 deletions src/fusiontrees/manipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,24 +243,46 @@ end
# -> A-move (foldleft, foldright) is complicated, needs to be reexpressed in standard form

# flip a duality flag of a fusion tree
function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, i::Int) where {I<:Sector,N₁,N₂}
function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, i::Int;
inv::Bool=false) where {I<:Sector,N₁,N₂}
@assert 0 < i ≤ N₁ + N₂
if i ≤ N₁
a = f₁.uncoupled[i]
fs = frobeniusschur(a) * twist(a)
factor = f₁.isdual[i] ? fs : one(fs)
χₐ = frobeniusschur(a)
θₐ = twist(a)
if !inv
factor = f₁.isdual[i] ? χₐ * θₐ : one(θₐ)
else
factor = f₁.isdual[i] ? one(θₐ) : χₐ * conj(θₐ)
end
isdual′ = TupleTools.setindex(f₁.isdual, !f₁.isdual[i], i)
f₁′ = FusionTree{I}(f₁.uncoupled, f₁.coupled, isdual′, f₁.innerlines, f₁.vertices)
return SingletonDict((f₁′, f₂) => factor)
else
i -= N₁
a = f₂.uncoupled[i]
factor = f₂.isdual[i] ? frobeniusschur(a) : twist(a)
χₐ = frobeniusschur(a)
θₐ = twist(a)
if !inv
factor = f₂.isdual[i] ? χₐ * one(θₐ) : θₐ
else
factor = f₂.isdual[i] ? conj(θₐ) : χₐ * one(θₐ)
end
isdual′ = TupleTools.setindex(f₂.isdual, !f₂.isdual[i], i)
f₂′ = FusionTree{I}(f₂.uncoupled, f₂.coupled, isdual′, f₂.innerlines, f₂.vertices)
return SingletonDict((f₁, f₂′) => factor)
end
end
function flip(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}, ind;
inv::Bool=false) where {I<:Sector,N₁,N₂}
f₁′, f₂′ = f₁, f₂
factor = one(sectorscalartype(I))
for i in ind
(f₁′, f₂′), s = only(flip(f₁′, f₂′, i; inv))
factor *= s
end
return SingletonDict((f₁′, f₂′) => factor)
end

# change to N₁ - 1, N₂ + 1
function bendright(f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}) where {I<:Sector,N₁,N₂}
Expand Down
16 changes: 9 additions & 7 deletions src/tensors/indexmanipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@

Return a new tensor that is isomorphic to `t` but where the arrows on the indices `i` that satisfy
`i ∈ I` are flipped, i.e. `space(t′, i) = flip(space(t, i))`.

!!! note
The isomorphism that `flip` applies to each of the indices `i ∈ I` is such that flipping two indices
that are afterwards contracted within an `@tensor` contraction will yield the same result as without
flipping those indices first. However, `flip` is not involutary, i.e. `flip(flip(t, I), I) != t` in
Comment thread
sanderdemeyer marked this conversation as resolved.
Outdated
general. To obtain the original tensor, one can use the `inv` keyword, i.e. it holds that
`flip(flip(t, I), I; inv=true) == t`.
"""
function flip(t::AbstractTensorMap, I)
function flip(t::AbstractTensorMap, I; inv::Bool=false)
P = flip(space(t), I)
t′ = similar(t, P)
for (f₁, f₂) in fusiontrees(t)
f₁′, f₂′ = f₁, f₂
factor = one(scalartype(t))
for i in I
(f₁′, f₂′), s = only(flip(f₁′, f₂′, i))
factor *= s
end
(f₁′, f₂′), factor = only(flip(f₁, f₂, I; inv))
scale!(t′[f₁′, f₂′], t[f₁, f₂], factor)
end
return t′
Expand Down
3 changes: 3 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(twist, A, 1)
test_rrule(twist, A, [1, 3])

test_rrule(flip, A, 1)
test_rrule(flip, A, [1, 3, 4])

D = randn(T, V[1] ⊗ V[2] ← V[3])
E = randn(T, V[4] ← V[5])
symmetricbraiding && test_rrule(⊗, D, E)
Expand Down
7 changes: 7 additions & 0 deletions test/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,13 @@ for V in spacelist
@test HrA12array ≈ convert(Array, HrA12)
end
end
@timedtestset "Index flipping: test flipping inverse" begin
t = rand(ComplexF64, V1 ⊗ V1' ← V1' ⊗ V1)
for i in 1:4
@test t ≈ flip(flip(t, i), i; inv=true)
@test t ≈ flip(flip(t, i; inv=true), i)
end
end
@timedtestset "Index flipping: test via explicit flip" begin
t = rand(ComplexF64, V1 ⊗ V1' ← V1' ⊗ V1)
F1 = unitary(flip(V1), V1)
Expand Down
Loading