Skip to content

Commit 532ea7f

Browse files
kshyattJutho
andauthored
Apply suggestions from code review for preprocessor
Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent f7a7cd3 commit 532ea7f

1 file changed

Lines changed: 5 additions & 13 deletions

File tree

src/planar/preprocessors.jl

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,12 @@ _add_adjoint(ex) = Expr(TO.prime, ex)
8383
# spaces from the rest of the expression. Construct the explicit BraidingTensor objects and
8484
# insert them in the expression.
8585
function _construct_braidingtensors(ex)
86-
function filter_f(expr)
87-
if TO.istensor(expr)
88-
return _remove_adjoint(TO.decomposetensor(expr)[1]) !=
89-
elseif TO.istensorexpr(expr)
90-
return any(filter_f, expr.args)
91-
else
92-
return false
93-
end
94-
end
9586
function extract_tensors(tensor_ex)
9687
if TO.istensor(tensor_ex)
97-
return [TO.decomposetensor(tensor_ex)[1]]
88+
obj = TO.decomposetensor(expr)[1]
89+
return (_remove_adjoint(obj) == ) ? Any[] : Any[obj]
9890
elseif TO.istensorexpr(tensor_ex)
99-
return collect(Iterators.flatmap(extract_tensors, filter(filter_f, tensor_ex.args)))
91+
return collect(Any, Iterators.flatmap(extract_tensors, tensor_ex.args))
10092
end
10193
end
10294
# get storagetype
@@ -122,7 +114,7 @@ function _construct_braidingtensors(ex)
122114
end
123115
end
124116
# if this is a definition, the lhs tensor is NOT yet defined
125-
no_τ_ex = reduce(vcat, Iterators.flatmap(extract_tensors, filter(filter_f, rhs.args)); init = Symbol[])
117+
no_τ_ex = collect(Any, Iterators.flatmap(extract_tensors, rhs.args))
126118
newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap, no_τ_ex)
127119
success ||
128120
throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex"))
@@ -134,7 +126,7 @@ function _construct_braidingtensors(ex)
134126
elseif TO.istensorexpr(ex)
135127
preargs = Vector{Any}()
136128
indexmap = Dict{Any, Any}()
137-
no_τ_ex = reduce(vcat, Iterators.flatmap(extract_tensors, filter(filter_f, ex.args)); init = Symbol[])
129+
no_τ_ex = collect(Any, Iterators.flatmap(extract_tensors, ex.args))
138130
newex, success = _construct_braidingtensors!(ex, preargs, indexmap, no_τ_ex)
139131
success ||
140132
throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex"))

0 commit comments

Comments
 (0)