22# special (2,2) tensor that implements a standard braiding operation
33#= ===================================================================#
44"""
5- struct BraidingTensor{T,S<: IndexSpace} <: AbstractTensorMap{T, S, 2, 2}
5+ struct BraidingTensor{T, S <: IndexSpace, A <: DenseVector{T} } <: AbstractTensorMap{T, S, 2, 2}
66 BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
7+ BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool=false) where {T, S, A}
78
89Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that
910braids the first input over the second input; its inverse can be obtained as the adjoint.
1011
1112It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and
12- `codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`.
13+ `codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. The storage type `TA`
14+ controls the array type of the braiding tensor used when indexing
15+ and multiplying with other tensors.
1316"""
14- struct BraidingTensor{T, S} <: AbstractTensorMap{T, S, 2, 2}
17+ struct BraidingTensor{T, S, A <: DenseVector{T} } <: AbstractTensorMap{T, S, 2, 2}
1518 V1:: S
1619 V2:: S
1720 adjoint:: Bool
18- function BraidingTensor {T, S} (V1:: S , V2:: S , adjoint:: Bool = false ) where {T, S <: IndexSpace }
19- for a in sectors (V1)
20- for b in sectors (V2)
21- for c in (a ⊗ b)
22- Nsymbol (a, b, c) == Nsymbol (b, a, c) ||
23- throw (ArgumentError (" Cannot define a braiding between $a and $b " ))
24- end
25- end
21+ function BraidingTensor {T, S, A} (V1:: S , V2:: S , adjoint:: Bool = false ) where {T, S <: IndexSpace , A <: DenseVector{T} }
22+ for a in sectors (V1), b in sectors (V2), c in (a ⊗ b)
23+ Nsymbol (a, b, c) == Nsymbol (b, a, c) ||
24+ throw (ArgumentError (" Cannot define a braiding between $a and $b " ))
2625 end
27- return new {T, S} (V1, V2, adjoint)
26+ return new {T, S, A } (V1, V2, adjoint)
2827 # partial construction: only construct rowr and colr when needed
2928 end
3029end
3130function BraidingTensor {T} (V1:: S , V2:: S , adjoint:: Bool = false ) where {T, S <: IndexSpace }
32- return BraidingTensor {T, S} (V1, V2, adjoint)
31+ return braidingtensortype (S, T) (V1, V2, adjoint)
3332end
34- function BraidingTensor {T} (V1:: IndexSpace , V2:: IndexSpace , adjoint:: Bool = false ) where {T}
35- return BraidingTensor {T} (promote (V1, V2)... , adjoint)
33+ function BraidingTensor (V1:: S , V2:: S , adjoint:: Bool = false ) where {S <: IndexSpace }
34+ T = BraidingStyle (sectortype (S)) isa SymmetricBraiding ? Float64 : ComplexF64
35+ return BraidingTensor {T} (V1, V2, adjoint)
3636end
3737function BraidingTensor (V1:: IndexSpace , V2:: IndexSpace , adjoint:: Bool = false )
3838 return BraidingTensor (promote (V1, V2)... , adjoint)
3939end
40- function BraidingTensor (V1:: S , V2:: S , adjoint:: Bool = false ) where {S <: IndexSpace }
41- T = BraidingStyle (sectortype (S)) isa SymmetricBraiding ? Float64 : ComplexF64
42- return BraidingTensor {T, S} (V1, V2, adjoint)
43- end
4440function BraidingTensor (V:: HomSpace , adjoint:: Bool = false )
4541 domain (V) == reverse (codomain (V)) ||
4642 throw (SpaceMismatch (" Cannot define a braiding on $V " ))
4743 return BraidingTensor (V[2 ], V[1 ], adjoint)
4844end
45+ function BraidingTensor {T, S, A} (V:: HomSpace , adjoint:: Bool = false ) where {T, S, A}
46+ domain (V) == reverse (codomain (V)) ||
47+ throw (SpaceMismatch (" Cannot define a braiding on $V " ))
48+ return BraidingTensor {T, S, A} (V[2 ], V[1 ], adjoint)
49+ end
4950function BraidingTensor {T} (V:: HomSpace , adjoint:: Bool = false ) where {T}
5051 domain (V) == reverse (codomain (V)) ||
5152 throw (SpaceMismatch (" Cannot define a braiding on $V " ))
5253 return BraidingTensor {T} (V[2 ], V[1 ], adjoint)
5354end
54- function Base. adjoint (b:: BraidingTensor{T, S} ) where {T, S}
55- return BraidingTensor {T, S} (b. V1, b. V2, ! b. adjoint)
56- end
5755
58- space (b:: BraidingTensor ) = b. adjoint ? b. V1 ⊗ b. V2 ← b. V2 ⊗ b. V1 : b. V2 ⊗ b. V1 ← b. V1 ⊗ b. V2
56+ function Base. adjoint (b:: BraidingTensor{T, S, A} ) where {T, S, A}
57+ return BraidingTensor {T, S, A} (b. V1, b. V2, ! b. adjoint)
58+ end
5959
60- # specializations to ignore the storagetype of BraidingTensor
61- promote_storagetype (:: Type{A} , :: Type{B} ) where {A <: BraidingTensor , B <: AbstractTensorMap } = storagetype (B)
62- promote_storagetype (:: Type{A} , :: Type{B} ) where {A <: AbstractTensorMap , B <: BraidingTensor } = storagetype (A)
63- promote_storagetype (:: Type{A} , :: Type{B} ) where {A <: BraidingTensor , B <: BraidingTensor } = storagetype (A)
60+ # these are here to make the preprocessing for `@planar` expressions less painful
61+ function braidingtensortype (:: Type{S} , :: Type{TorA} ) where {S <: IndexSpace , TorA}
62+ A = similarstoragetype (TorA)
63+ return BraidingTensor{scalartype (A), S, A}
64+ end
65+ braidingtensortype (V:: S , :: Type{TorA} ) where {S <: IndexSpace , TorA} = braidingtensortype (S, TorA)
66+ braidingtensortype (V1:: S , V2:: S , :: Type{TorA} ) where {S <: IndexSpace , TorA} = braidingtensortype (S, TorA)
67+ function braidingtensortype (V1:: IndexSpace , V2:: IndexSpace , :: Type{TorA} ) where {TorA}
68+ S = promote (V1, V2)
69+ return braidingtensortype (S... , TorA)
70+ end
71+ function braidingtensortype (V:: HomSpace , :: Type{TorA} ) where {TorA}
72+ return braidingtensortype (spacetype (V), TorA)
73+ end
6474
65- promote_storagetype (:: Type{T} , :: Type{A} , :: Type{B} ) where {T <: Number , A <: BraidingTensor , B <: AbstractTensorMap } =
66- similarstoragetype (B, T)
67- promote_storagetype (:: Type{T} , :: Type{A} , :: Type{B} ) where {T <: Number , A <: AbstractTensorMap , B <: BraidingTensor } =
68- similarstoragetype (A, T)
69- promote_storagetype (:: Type{T} , :: Type{A} , :: Type{B} ) where {T <: Number , A <: BraidingTensor , B <: BraidingTensor } =
70- similarstoragetype (A, T)
75+ storagetype (:: Type{BraidingTensor{T, S, A}} ) where {T, S, A} = A
76+ space (b:: BraidingTensor ) = b. adjoint ? b. V1 ⊗ b. V2 ← b. V2 ⊗ b. V1 : b. V2 ⊗ b. V1 ← b. V1 ⊗ b. V2
7177
7278function Base. getindex (b:: BraidingTensor )
7379 sectortype (b) === Trivial || throw (SectorMismatch ())
@@ -99,6 +105,13 @@ function _braiding_factor(f₁, f₂, inv::Bool = false)
99105 return r
100106end
101107
108+ # generates scalar indexing errors on GPU
109+ function fill_braidingsubblock! (data, val)
110+ f (I) = ((I[1 ] == I[4 ]) & (I[2 ] == I[3 ])) * val
111+ return data .= f .(CartesianIndices (data))
112+ end
113+
114+
102115@inline function subblock (
103116 b:: BraidingTensor , (f₁, f₂):: Tuple{FusionTree{I, 2}, FusionTree{I, 2}}
104117 ) where {I <: Sector }
@@ -113,17 +126,10 @@ end
113126 throw (SectorMismatch ())
114127 end
115128 d = (dims (codomain (b), f₁. uncoupled)... , dims (domain (b), f₂. uncoupled)... )
116- n1 = d[1 ] * d[2 ]
117- n2 = d[3 ] * d[4 ]
118- data = sreshape (StridedView (Matrix {eltype(b)} (undef, n1, n2)), d)
119- fill! (data, zero (eltype (b)))
120-
129+ data_parent = storagetype (b)(undef, prod (d))
130+ data = sreshape (StridedView (data_parent), d)
121131 r = _braiding_factor (f₁, f₂, b. adjoint)
122- if ! isnothing (r)
123- @inbounds for i in axes (data, 1 ), j in axes (data, 2 )
124- data[i, j, j, i] = r
125- end
126- end
132+ isnothing (r) ? zerovector! (data) : fill_braidingsubblock! (data, r)
127133 return data
128134end
129135
@@ -134,49 +140,50 @@ TensorMap(b::BraidingTensor) = copy!(similar(b), b)
134140Base. convert (:: Type{TensorMap} , b:: BraidingTensor ) = TensorMap (b)
135141
136142Base. complex (b:: BraidingTensor{<:Complex} ) = b
137- function Base. complex (b:: BraidingTensor )
138- return BraidingTensor {complex(scalartype(b))} (space (b), b. adjoint)
143+ function Base. complex (b:: BraidingTensor{T, S, A} ) where {T, S, A}
144+ Tc = complex (T)
145+ Ac = similarstoragetype (A, Tc)
146+ return BraidingTensor {Tc, S, Ac} (space (b), b. adjoint)
139147end
140148
141- function block (b:: BraidingTensor , s:: Sector )
142- I = sectortype (b)
143- I == typeof (s) || throw (SectorMismatch ())
144-
145- # TODO : probably always square?
146- m = blockdim (codomain (b), s)
147- n = blockdim (domain (b), s)
148- data = Matrix {eltype(b)} (undef, (m, n))
149-
150- length (data) == 0 && return data # s ∉ blocksectors(b)
151-
152- data = fill! (data, zero (eltype (b)))
153-
149+ # Trivial
150+ function fill_braidingblock! (data, b:: BraidingTensor , s:: Trivial )
154151 V1, V2 = codomain (b)
155- if sectortype (b) === Trivial
156- d1, d2 = dim (V1), dim (V2)
157- subblock = sreshape (StridedView (data), (d1, d2, d2, d1))
158- @inbounds for i in axes (subblock, 1 ), j in axes (subblock, 2 )
159- subblock[i, j, j, i] = one (eltype (b))
160- end
161- return data
162- end
152+ d1, d2 = dim (V1), dim (V2)
153+ subblock = sreshape (StridedView (data), (d1, d2, d2, d1))
154+ fill_braidingsubblock! (subblock, one (eltype (b)))
155+ return data
156+ end
163157
158+ # Nontrivial
159+ function fill_braidingblock! (data, b:: BraidingTensor , s:: Sector )
164160 base_offset = first (blockstructure (b)[s][2 ]) - 1
165161
166162 for ((f₁, f₂), (sz, str, off)) in pairs (subblockstructure (space (b)))
167163 (f₁. coupled == f₂. coupled == s) || continue
168164 r = _braiding_factor (f₁, f₂, b. adjoint)
169- isnothing (r) && continue
170165 # change offset to account for single block
171166 subblock = StridedView (data, sz, str, off - base_offset)
172- @inbounds for i in axes (subblock, 1 ), j in axes (subblock, 2 )
173- subblock[i, j, j, i] = r
174- end
167+ isnothing (r) ? zerovector! (subblock) : fill_braidingsubblock! (subblock, r)
175168 end
176-
177169 return data
178170end
179171
172+ function block (b:: BraidingTensor , s:: Sector )
173+ I = sectortype (b)
174+ I == typeof (s) || throw (SectorMismatch ())
175+
176+ # TODO : probably always square?
177+ m = blockdim (codomain (b), s)
178+ n = blockdim (domain (b), s)
179+
180+ data = reshape (storagetype (b)(undef, m * n), (m, n))
181+
182+ m * n == 0 && return data # s ∉ blocksectors(b)
183+
184+ return fill_braidingblock! (data, b, s)
185+ end
186+
180187# Index manipulations
181188# -------------------
182189has_shared_permute (t:: BraidingTensor , :: Index2Tuple ) = false
0 commit comments