Skip to content

Commit 5eb73e8

Browse files
committed
Fix regression when constructing an ITensorNetwork using a TreeTensorNetwork resulting in empty tensors.
1 parent 1cbd0b5 commit 5eb73e8

3 files changed

Lines changed: 32 additions & 8 deletions

File tree

src/treetensornetworks/abstracttreetensornetwork.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ end
9494

9595
# For ambiguity error
9696
function Base.truncate(tn::AbstractTTN, edge::AbstractEdge; kwargs...)
97-
return typeof(tn)(truncate(itensornetwork(tn), edge; kwargs...))
97+
return typeof(tn)(truncate(ITensorNetwork(tn), edge; kwargs...))
9898
end
9999

100100
#
@@ -108,7 +108,7 @@ function NDTensors.contract(
108108
tn = copy(tn)
109109
# reverse post order vertices
110110
traversal_order = reverse(post_order_dfs_vertices(tn, root_vertex))
111-
return contract(itensornetwork(tn); sequence = traversal_order, kwargs...)
111+
return contract(ITensorNetwork(tn); sequence = traversal_order, kwargs...)
112112
# # forward post order edges
113113
# tn = copy(tn)
114114
# for e in post_order_dfs_edges(tn, root_vertex)

src/treetensornetworks/treetensornetwork.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,14 @@ end
8383

8484
# Field access
8585
"""
86-
itensornetwork(tn::TreeTensorNetwork) -> ITensorNetwork
86+
ITensorNetwork(tn::TreeTensorNetwork) -> ITensorNetwork
8787
8888
Convert a `TreeTensorNetwork` to a plain `ITensorNetwork`, discarding orthogonality
8989
metadata. The returned network shares the same underlying tensor data.
9090
9191
See also: [`TreeTensorNetwork`](@ref), [`ttn`](@ref).
9292
"""
93-
itensornetwork(tn::TTN) = getfield(tn, :tensornetwork)
93+
ITensorNetwork(tn::TTN) = copy(tn.tensornetwork)
9494

9595
"""
9696
ortho_region(tn::TreeTensorNetwork) -> Indices
@@ -99,10 +99,10 @@ Return the set of vertices that currently form the orthogonality center of `tn`.
9999
100100
See also: [`orthogonalize`](@ref).
101101
"""
102-
ortho_region(tn::TTN) = getfield(tn, :ortho_region)
102+
ortho_region(tn::TTN) = tn.ortho_region
103103

104104
# Required for `AbstractITensorNetwork` interface
105-
data_graph(tn::TTN) = data_graph(itensornetwork(tn))
105+
data_graph(tn::TTN) = data_graph(tn.tensornetwork)
106106

107107
function data_graph_type(G::Type{<:TTN})
108108
return data_graph_type(fieldtype(G, :tensornetwork))

test/test_ttns.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using DataGraphs: vertex_data
22
using Graphs: vertices
3-
using ITensorNetworks: contract, ortho_region, siteinds, ttn
3+
using ITensorNetworks:
4+
ITensorNetwork, TreeTensorNetwork, contract, ortho_region, orthogonalize, siteinds, ttn
45
using ITensors: @disable_warn_order, random_itensor
56
using LinearAlgebra: norm
67
using NamedGraphs.NamedGraphGenerators: named_comb_tree
@@ -32,7 +33,30 @@ using Test: @test, @testset
3233
@test norm(S - S1) < 1.0e2 * cutoff
3334
end
3435

36+
@testset "Convert ITN <-> TTN" begin
37+
g = named_comb_tree((3, 2))
38+
sites = siteinds("S=1/2", g)
39+
40+
psi = ttn(sites) # zero-initialised
41+
psi = ttn(v -> "Up", sites) # product state
42+
43+
itn = ITensorNetwork(psi) # TTN → ITensorNetwork
44+
new_psi = TreeTensorNetwork(itn) # ITensorNetwork → TTN
45+
46+
@test !(new_psi === itn) # test we make a copy
47+
end
48+
3549
@testset "Ortho" begin
36-
# TODO
50+
g = named_comb_tree((3, 2))
51+
sites = siteinds("S=1/2", g)
52+
53+
psi = ttn(sites) # zero-initialised
54+
psi = ttn(v -> "Up", sites) # product state
55+
56+
v1 = collect(vertices(psi))[1]
57+
v2 = collect(vertices(psi))[2]
58+
59+
@test collect(ortho_region(orthogonalize(psi, v1))) == [v1]
60+
@test collect(ortho_region(orthogonalize(psi, [v1, v2]))) == [v1, v2]
3761
end
3862
end

0 commit comments

Comments
 (0)