@@ -83,6 +83,23 @@ _add_adjoint(ex) = Expr(TO.prime, ex)
8383# spaces from the rest of the expression. Construct the explicit BraidingTensor objects and
8484# insert them in the expression.
8585function _construct_braidingtensors (ex)
86+ function filter_f (expr)
87+ if TO. istensor (expr)
88+ return _remove_adjoint (TO. decomposetensor (expr)[1 ]) != :τ
89+ elseif TO. istensorexpr (expr)
90+ return any (filter_f, expr. args)
91+ else
92+ return false
93+ end
94+ end
95+ function extract_tensors (tensor_ex)
96+ if TO. istensor (tensor_ex)
97+ return [TO. decomposetensor (tensor_ex)[1 ]]
98+ elseif TO. istensorexpr (tensor_ex)
99+ return collect (Iterators. flatmap (extract_tensors, filter (filter_f, tensor_ex. args)))
100+ end
101+ end
102+ # get storagetype
86103 ex isa Expr || return ex
87104 if ex. head == :macrocall && ex. args[1 ] == Symbol (" @notensor" )
88105 return ex
@@ -104,7 +121,9 @@ function _construct_braidingtensors(ex)
104121 )
105122 end
106123 end
107- newrhs, success = _construct_braidingtensors! (rhs, preargs, indexmap)
124+ # if this is a definition, the lhs tensor is NOT yet defined
125+ no_τ_ex = reduce (vcat, Iterators. flatmap (extract_tensors, filter (filter_f, rhs. args)); init = Symbol[])
126+ newrhs, success = _construct_braidingtensors! (rhs, preargs, indexmap, no_τ_ex)
108127 success ||
109128 throw (ArgumentError (" cannot determine the spaces of all braiding tensors in $ex " ))
110129 pre = Expr (
@@ -115,7 +134,8 @@ function _construct_braidingtensors(ex)
115134 elseif TO. istensorexpr (ex)
116135 preargs = Vector {Any} ()
117136 indexmap = Dict {Any, Any} ()
118- newex, success = _construct_braidingtensors! (ex, preargs, indexmap)
137+ no_τ_ex = reduce (vcat, Iterators. flatmap (extract_tensors, filter (filter_f, ex. args)); init = Symbol[])
138+ newex, success = _construct_braidingtensors! (ex, preargs, indexmap, no_τ_ex)
119139 success ||
120140 throw (ArgumentError (" cannot determine the spaces of all braiding tensors in $ex " ))
121141 pre = Expr (
@@ -128,7 +148,7 @@ function _construct_braidingtensors(ex)
128148 end
129149end
130150
131- function _construct_braidingtensors! (ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression
151+ function _construct_braidingtensors! (ex, preargs, indexmap, non_braiding ) # ex is guaranteed to be a single tensor expression
132152 if TO. isscalarexpr (ex)
133153 # ex could be tensorscalar call with more braiding tensors
134154 return _construct_braidingtensors (ex), true
@@ -163,7 +183,9 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
163183 end
164184 if foundV1 && foundV2
165185 s = gensym (:τ )
166- constructex = Expr (:call , GlobalRef (TensorKit, :BraidingTensor ), V1, V2)
186+ storageex = Expr (:call , GlobalRef (TensorKit, :promote_storagetype ), non_braiding... )
187+ braidingex = Expr (:call , GlobalRef (TensorKit, :braidingtensortype ), V1, V2, storageex)
188+ constructex = Expr (:call , braidingex, V1, V2)
167189 push! (preargs, Expr (:(= ), s, constructex))
168190 obj = _is_adjoint (obj) ? _add_adjoint (s) : s
169191 success = true
@@ -196,7 +218,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
196218 newargs = Vector {Any} (undef, length (args))
197219 success = true
198220 for i in 1 : length (ex. args)
199- newargs[i], successa = _construct_braidingtensors! (args[i], preargs, indexmap)
221+ newargs[i], successa = _construct_braidingtensors! (args[i], preargs, indexmap, non_braiding )
200222 success = success && successa
201223 end
202224 newex = Expr (ex. head, newargs... )
@@ -212,7 +234,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
212234 for i in 2 : length (ex. args)
213235 successes[i] && continue
214236 newargs[i], successa = _construct_braidingtensors! (
215- args[i], preargs, indexmap
237+ args[i], preargs, indexmap, non_braiding
216238 )
217239 successes[i] = successa
218240 end
@@ -232,7 +254,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
232254 indices = [TO. getindices (arg) for arg in args]
233255 for i in 2 : length (ex. args)
234256 indexmapa = copy (indexmap)
235- newargs[i], successa = _construct_braidingtensors! (args[i], preargs, indexmapa)
257+ newargs[i], successa = _construct_braidingtensors! (args[i], preargs, indexmapa, non_braiding )
236258 for l in indices[i]
237259 if ! haskey (indexmap, l) && haskey (indexmapa, l)
238260 indexmap[l] = indexmapa[l]
@@ -243,10 +265,10 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
243265 newex = Expr (ex. head, newargs... )
244266 return newex, success
245267 elseif isexpr (ex, :call ) && ex. args[1 ] == :/ && length (ex. args) == 3
246- newarg, success = _construct_braidingtensors! (ex. args[2 ], preargs, indexmap)
268+ newarg, success = _construct_braidingtensors! (ex. args[2 ], preargs, indexmap, non_braiding )
247269 return Expr (:call , :/ , newarg, ex. args[3 ]), success
248270 elseif isexpr (ex, :call ) && ex. args[1 ] == :\ && length (ex. args) == 3
249- newarg, success = _construct_braidingtensors! (ex. args[3 ], preargs, indexmap)
271+ newarg, success = _construct_braidingtensors! (ex. args[3 ], preargs, indexmap, non_braiding )
250272 return Expr (:call , :\ , ex. args[2 ], newarg), success
251273 else
252274 error (" unexpected expression $ex " )
0 commit comments