33#= ===================================================================#
44"""
55 struct BraidingTensor{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
6- BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool=false) where {S<:IndexSpace, A <: DenseVector{<:Number}}
6+ BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
7+ BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool=false) where {T, S, A}
78
89Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that
910braids the first input over the second input; its inverse can be obtained as the adjoint.
@@ -27,16 +28,11 @@ struct BraidingTensor{T, S, A} <: AbstractTensorMap{T, S, 2, 2}
2728 end
2829end
2930function BraidingTensor {T} (V1:: S , V2:: S , adjoint:: Bool = false ) where {T, S <: IndexSpace }
30- return BraidingTensor {T, S, Vector {T}} (V1, V2, adjoint)
31+ return BraidingTensor {T, S, similarstoragetype {T}} (V1, V2, adjoint)
3132end
3233function BraidingTensor (V1:: S , V2:: S , adjoint:: Bool = false ) where {S <: IndexSpace }
3334 T = BraidingStyle (sectortype (S)) isa SymmetricBraiding ? Float64 : ComplexF64
34- return BraidingTensor {T, S, Vector{T}} (V1, V2, adjoint)
35- end
36- # necessary due to HomSpace ctor below
37- function BraidingTensor (V1:: S , V2:: S , A, adjoint:: Bool = false ) where {S <: IndexSpace }
38- T = eltype (A)
39- return BraidingTensor {T, S, A} (V1, V2, adjoint)
35+ return BraidingTensor {T} (V1, V2, adjoint)
4036end
4137function BraidingTensor (V1:: IndexSpace , V2:: IndexSpace , adjoint:: Bool = false )
4238 return BraidingTensor (promote (V1, V2)... , adjoint)
@@ -46,10 +42,10 @@ function BraidingTensor(V::HomSpace, adjoint::Bool = false)
4642 throw (SpaceMismatch (" Cannot define a braiding on $V " ))
4743 return BraidingTensor (V[2 ], V[1 ], adjoint)
4844end
49- function BraidingTensor (V:: HomSpace , A, adjoint:: Bool = false )
45+ function BraidingTensor {T, S, A} (V:: HomSpace , adjoint:: Bool = false ) where {T, S, A}
5046 domain (V) == reverse (codomain (V)) ||
5147 throw (SpaceMismatch (" Cannot define a braiding on $V " ))
52- return BraidingTensor (V[2 ], V[1 ], A , adjoint)
48+ return BraidingTensor {T, S, A} (V[2 ], V[1 ], adjoint)
5349end
5450function BraidingTensor {T} (V:: HomSpace , adjoint:: Bool = false ) where {T}
5551 domain (V) == reverse (codomain (V)) ||
@@ -93,7 +89,8 @@ function _braiding_factor(f₁, f₂, inv::Bool = false)
9389 return r
9490end
9591
96- function _set_subblock! (data, val)
92+ # generates scalar indexing errors on GPU
93+ function fill_braidingsubblock! (data, val)
9794 f (I) = ((I[1 ] == I[4 ]) & (I[2 ] == I[3 ])) * val
9895 return data .= f .(CartesianIndices (data))
9996end
117114 n2 = d[3 ] * d[4 ]
118115 data_parent = storagetype (b)(undef, prod (d))
119116 data = sreshape (StridedView (data_parent), d)
120- fill! (data, zero (eltype (b)))
121-
122117 r = _braiding_factor (f₁, f₂, b. adjoint)
123- ! isnothing (r) && _set_subblock ! (data, r)
118+ ! isnothing (r) && fill_braidingsubblock ! (data, r)
124119 return data
125120end
126121
@@ -132,28 +127,33 @@ Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)
132127
133128Base. complex (b:: BraidingTensor{<:Complex} ) = b
134129function Base. complex (b:: BraidingTensor{T, S, A} ) where {T, S, A}
135- Ac = similarstoragetype (A, complex (T))
136- return BraidingTensor (space (b), Ac, b. adjoint)
130+ Tc = complex (T)
131+ Ac = similarstoragetype (A, Tc)
132+ return BraidingTensor {Tc, S, Ac} (space (b), b. adjoint)
137133end
138134
139- function _trivial_subblock! (data, b:: BraidingTensor )
135+ # Trivial
136+ function fill_braidingblock! (data, b:: BraidingTensor , s:: Trivial )
140137 V1, V2 = codomain (b)
141138 d1, d2 = dim (V1), dim (V2)
142139 subblock = sreshape (StridedView (data), (d1, d2, d2, d1))
143- _set_subblock ! (subblock, one (eltype (b)))
140+ fill_braidingsubblock ! (subblock, one (eltype (b)))
144141 return data
145142end
146143
147- function _nontrivial_subblock! (data, b:: BraidingTensor , s:: Sector )
144+ # Nontrivial
145+ function fill_braidingblock! (data, b:: BraidingTensor , s:: Sector )
148146 base_offset = first (blockstructure (b)[s][2 ]) - 1
149147
150148 for ((f₁, f₂), (sz, str, off)) in pairs (subblockstructure (space (b)))
151149 (f₁. coupled == f₂. coupled == s) || continue
152150 r = _braiding_factor (f₁, f₂, b. adjoint)
153- isnothing (r) && continue
154151 # change offset to account for single block
155152 subblock = StridedView (data, sz, str, off - base_offset)
156- _set_subblock! (subblock, r)
153+ # without the zero-value, the non-trivial block is not set
154+ # correctly in the GPU case
155+ val = isnothing (r) ? zero (eltype (data)) : r
156+ fill_braidingsubblock! (subblock, val)
157157 end
158158 return data
159159end
@@ -169,13 +169,8 @@ function block(b::BraidingTensor, s::Sector)
169169 data = reshape (storagetype (b)(undef, m * n), (m, n))
170170
171171 m * n == 0 && return data # s ∉ blocksectors(b)
172- fill! (data, zero (eltype (b)))
173172
174- if sectortype (b) === Trivial
175- return _trivial_subblock! (data, b)
176- else
177- return _nontrivial_subblock! (data, b, s)
178- end
173+ return fill_braidingblock! (data, b, s)
179174end
180175
181176# Index manipulations
0 commit comments