|
1 | 1 | module ITensorNetworksNextTensorOperationsExt |
2 | 2 |
|
3 | 3 | using BackendSelection: @Algorithm_str, Algorithm |
| 4 | +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, ismul, symnameddims, |
| 5 | + substitute |
| 6 | +using ITensorNetworksNext.LazyNamedDimsArrays.TermInterface: arguments |
4 | 7 | using NamedDimsArrays: inds |
5 | | -using ITensorNetworksNext: ITensorNetworksNext, contraction_sequence_to_expr |
6 | 8 | using TensorOperations: TensorOperations, optimaltree |
7 | 9 |
|
8 | | -function ITensorNetworksNext.contraction_sequence(::Algorithm"optimal", tn::Vector{<:AbstractArray}) |
9 | | - network = collect.(inds.(tn)) |
10 | | - #Converting dims to Float64 to minimize overflow issues |
11 | | - inds_to_dims = Dict(i => Float64(length(i)) for i in unique(reduce(vcat, network))) |
12 | | - seq, _ = optimaltree(network, inds_to_dims) |
13 | | - return contraction_sequence_to_expr(seq) |
| 10 | +function contraction_tree_to_expr(f, tree) |
| 11 | + return if !(tree isa AbstractVector) |
| 12 | + f(tree) |
| 13 | + else |
| 14 | + prod(Base.Fix1(contraction_tree_to_expr, f), tree) |
| 15 | + end |
| 16 | +end |
| 17 | + |
| 18 | +function LazyNamedDimsArrays.optimize_contraction_order(alg::Algorithm"optimal", a) |
| 19 | + @assert ismul(a) |
| 20 | + ts = arguments(a) |
| 21 | + inds_network = collect.(inds.(ts)) |
| 22 | + # Converting dims to Float64 to minimize overflow issues |
| 23 | + inds_to_dims = Dict(i => Float64(length(i)) for i in reduce(∪, inds_network)) |
| 24 | + tree, _ = optimaltree(inds_network, inds_to_dims) |
| 25 | + return contraction_tree_to_expr(i -> ts[i], tree) |
14 | 26 | end |
15 | 27 |
|
16 | 28 | end |
0 commit comments