Skip to content

Commit dc22cf8

Browse files
mtfishmanclaude
andcommitted
Drop _TreeTensorNetwork inner ctor; share unregister-inds helper
- TreeTensorNetwork now uses the auto-generated all-fields constructor plus a single outer ctor that performs the is_tree check, matching the ITensorNetwork constructor design. - Add ITensorNetwork{V}() empty ctor used as the seed for the tensor-collection constructor, replacing the explicit three-field call. - Extract _unregister_inds! so _set_vertex_data! and _rem_vertex! share their reverse-map cleanup, without altering the in-place vertex_data update path. - Build dictionaries via map(Indices(...)) in test/utils.jl. - Fix the TreeTensorNetwork jldoctest to use a properly connected two-vertex example. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 62b91ad commit dc22cf8

3 files changed

Lines changed: 67 additions & 86 deletions

File tree

src/itensornetwork.jl

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
using DataGraphs: DataGraphs, set_vertex_data!, underlying_graph, vertex_data
22
using Dictionaries: Dictionaries, Dictionary
3-
using Graphs: Graphs, add_edge!, add_vertex!, edges, has_edge, has_vertex, neighbors,
4-
rem_edge!, rem_vertex!, vertices
5-
using ITensors: ITensors, ITensor, Index, inds
6-
using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype
3+
using Graphs:
4+
Graphs, add_edge!, add_vertex!, has_edge, has_vertex, neighbors, rem_edge!, rem_vertex!
5+
using ITensors: ITensor, Index, inds
6+
using NamedGraphs: NamedGraphs, NamedGraph, vertextype
77

88
"""
99
ITensorNetwork{V}
@@ -69,14 +69,22 @@ end
6969
# Constructors
7070
#
7171

72+
# Empty network with no vertices, used as the starting point for the
73+
# tensor-collection constructor below.
74+
function ITensorNetwork{V}() where {V}
75+
return ITensorNetwork{V}(
76+
NamedGraph{V}(), Dictionary{V, ITensor}(), Dict{Index, Set{V}}()
77+
)
78+
end
79+
7280
# Construct by feeding `tensors` through `set_vertex_data!` one vertex
7381
# at a time — this centralizes the reverse-map registration, edge
7482
# inference, and hypergraph check in a single place (the `setindex!`
7583
# code path). Walking `keys(tensors)` in order makes the resulting
7684
# `neighbors(g, v)` / `edges(g)` iteration order deterministic in the
7785
# input order.
7886
function ITensorNetwork{V}(tensors) where {V}
79-
tn = ITensorNetwork{V}(NamedGraph{V}(), Dictionary{V, ITensor}(), Dict{Index, Set{V}}())
87+
tn = ITensorNetwork{V}()
8088
for v in keys(tensors)
8189
set_vertex_data!(tn, tensors[v], v)
8290
end
@@ -100,33 +108,43 @@ Base.copy(tn::ITensorNetwork) = ITensorNetwork(map(copy, vertex_data(tn)))
100108
# Mutation: keep `graph`, `vertex_data`, and `ind_to_vertices` in sync.
101109
#
102110

111+
# Drop the inds of `vertex_data[v]` from the reverse map, leaving
112+
# `vertex_data` and `graph` themselves untouched. Used both as a
113+
# prelude to overwriting `v` and as a step in `_rem_vertex!`.
114+
function _unregister_inds!(
115+
vertex_data::Dictionary{V, ITensor},
116+
ind_to_vertices::Dict{Index, Set{V}},
117+
v
118+
) where {V}
119+
haskey(vertex_data, v) || return nothing
120+
for i in inds(vertex_data[v])
121+
owners = ind_to_vertices[i]
122+
delete!(owners, v)
123+
isempty(owners) && delete!(ind_to_vertices, i)
124+
end
125+
return nothing
126+
end
127+
103128
# Write `value` to vertex `v`, updating the reverse map and reconciling
104-
# edges so the graph-edge ↔ shared-`Index` invariant holds. Cost is
105-
# O(deg(v) + |inds(value)|). If `v` isn't already in the network, it's
106-
# added — so this is also the natural way to grow the network one tensor
107-
# at a time without a separate `add_vertex!` step. Operates on raw
108-
# storage so `ITensorNetwork` and `TreeTensorNetwork` can share it.
129+
# edges so the graph-edge ↔ shared-`Index` invariant holds. If `v` is
130+
# new it's added to the graph — so this is also the natural way to grow
131+
# the network one tensor at a time without a separate `add_vertex!`
132+
# step. Operates on raw storage so `ITensorNetwork` and
133+
# `TreeTensorNetwork` can share it.
109134
function _set_vertex_data!(
110135
graph::NamedGraph{V},
111136
vertex_data::Dictionary{V, ITensor},
112137
ind_to_vertices::Dict{Index, Set{V}},
113138
value,
114139
v
115140
) where {V}
116-
# Add the vertex to the graph if it's new.
117141
has_vertex(graph, v) || add_vertex!(graph, v)
118-
# Unregister old inds of `vertex_data[v]` from the reverse map.
119-
if haskey(vertex_data, v)
120-
for i in inds(vertex_data[v])
121-
owners = ind_to_vertices[i]
122-
delete!(owners, v)
123-
isempty(owners) && delete!(ind_to_vertices, i)
124-
end
125-
end
126-
# Write the new tensor. `Dictionaries.set!` inserts or updates;
127-
# plain `setindex!` would error on a vertex not already in the dict.
142+
_unregister_inds!(vertex_data, ind_to_vertices, v)
143+
# `set!` updates in place when `v` is already present, preserving
144+
# the insertion order of `vertex_data`. Plain `setindex!` would
145+
# error on a missing key, and `insert!` would error on an existing
146+
# one — `set!` handles both branches.
128147
Dictionaries.set!(vertex_data, v, value)
129-
# Register new inds.
130148
for i in inds(value)
131149
push!(get!(ind_to_vertices, i, Set{V}()), v)
132150
length(ind_to_vertices[i]) <= 2 || error(
@@ -162,14 +180,8 @@ function _rem_vertex!(
162180
ind_to_vertices::Dict{Index, Set{V}},
163181
v
164182
) where {V}
165-
if haskey(vertex_data, v)
166-
for i in inds(vertex_data[v])
167-
owners = ind_to_vertices[i]
168-
delete!(owners, v)
169-
isempty(owners) && delete!(ind_to_vertices, i)
170-
end
171-
delete!(vertex_data, v)
172-
end
183+
_unregister_inds!(vertex_data, ind_to_vertices, v)
184+
haskey(vertex_data, v) && delete!(vertex_data, v)
173185
rem_vertex!(graph, v)
174186
return nothing
175187
end

src/treetensornetworks/treetensornetwork.jl

Lines changed: 17 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using DataGraphs: DataGraphs, set_vertex_data!, underlying_graph, vertex_data
2-
using Dictionaries: Dictionaries, Dictionary, Indices
3-
using Graphs: Graphs, add_vertex!, has_vertex, is_tree, rem_vertex!, vertices
2+
using Dictionaries: Dictionary, Indices
3+
using Graphs: Graphs, is_tree, rem_vertex!, vertices
44
using ITensors: ITensor, Index
55
using NamedGraphs.GraphsExtensions: vertextype
66
using NamedGraphs: NamedGraph
@@ -24,27 +24,6 @@ struct TreeTensorNetwork{V} <: AbstractTreeTensorNetwork{V}
2424
vertex_data::Dictionary{V, ITensor}
2525
ind_to_vertices::Dict{Index, Set{V}}
2626
ortho_region::Indices{V}
27-
global function _TreeTensorNetwork(
28-
graph::NamedGraph{V},
29-
vertex_data::Dictionary{V, ITensor},
30-
ind_to_vertices::Dict{Index, Set{V}},
31-
ortho_region::Indices{V}
32-
) where {V}
33-
@assert is_tree(graph)
34-
return new{V}(graph, vertex_data, ind_to_vertices, ortho_region)
35-
end
36-
end
37-
38-
function _TreeTensorNetwork(tn::ITensorNetwork{V}, ortho_region::Indices{V}) where {V}
39-
return _TreeTensorNetwork(tn.graph, tn.vertex_data, tn.ind_to_vertices, ortho_region)
40-
end
41-
42-
function _TreeTensorNetwork(tn::ITensorNetwork{V}, ortho_region) where {V}
43-
return _TreeTensorNetwork(tn, Indices{V}(ortho_region))
44-
end
45-
46-
function _TreeTensorNetwork(tn::ITensorNetwork)
47-
return _TreeTensorNetwork(tn, vertices(tn))
4827
end
4928

5029
"""
@@ -61,31 +40,27 @@ Throws an error if the underlying graph of `tn` is not a tree.
6140
# Example
6241
6342
```jldoctest
64-
julia> using NamedGraphs.NamedGraphGenerators: named_comb_tree
65-
66-
julia> using NamedGraphs: NamedGraph
43+
julia> using ITensors: Index, ITensor
6744
6845
julia> using Graphs: vertices
6946
70-
julia> using ITensors: ITensor
47+
julia> i, j, k = Index(2, "i"), Index(2, "j"), Index(2, "k");
7148
72-
julia> g = named_comb_tree((2, 2));
49+
julia> itn = ITensorNetwork([ITensor(i, j), ITensor(j, k)]);
7350
74-
julia> s = siteinds("S=1/2", g);
75-
76-
julia> tensors = Dict(v => ITensor(s[v]...) for v in vertices(g));
77-
78-
julia> itn = ITensorNetwork(tensors);
79-
80-
julia> ttn_state = TreeTensorNetwork(itn; ortho_region = [first(vertices(itn))]);
51+
julia> ttn = TreeTensorNetwork(itn; ortho_region = [first(vertices(itn))]);
8152
8253
```
8354
8455
See also: [`ITensorNetwork`](@ref), [`orthogonalize`](@ref).
8556
"""
86-
function TreeTensorNetwork(tn::ITensorNetwork; ortho_region = vertices(tn))
87-
return _TreeTensorNetwork(tn, ortho_region)
57+
function TreeTensorNetwork(tn::ITensorNetwork{V}; ortho_region = vertices(tn)) where {V}
58+
@assert is_tree(tn)
59+
return TreeTensorNetwork{V}(
60+
tn.graph, tn.vertex_data, tn.ind_to_vertices, Indices{V}(ortho_region)
61+
)
8862
end
63+
8964
function TreeTensorNetwork{V}(tn::ITensorNetwork) where {V}
9065
return TreeTensorNetwork(ITensorNetwork{V}(tn))
9166
end
@@ -132,24 +107,19 @@ function Graphs.rem_vertex!(tn::TTN, v)
132107
return tn
133108
end
134109

135-
function Base.copy(tn::TTN)
136-
V = vertextype(tn)
137-
return _TreeTensorNetwork(
110+
function Base.copy(tn::TTN{V}) where {V}
111+
return TreeTensorNetwork{V}(
138112
copy(tn.graph),
139113
map(copy, tn.vertex_data),
140114
Dict{Index, Set{V}}(i => copy(vs) for (i, vs) in tn.ind_to_vertices),
141115
copy(tn.ortho_region)
142116
)
143117
end
144118

145-
#
146-
# Constructor
147-
#
148-
149119
# set_ortho_region: low-level update of the ortho_region metadata only,
150120
# without any gauge transformations. To move the orthogonality center use orthogonalize.
151-
function set_ortho_region(tn::TTN, ortho_region)
152-
return _TreeTensorNetwork(
153-
tn.graph, tn.vertex_data, tn.ind_to_vertices, Indices{vertextype(tn)}(ortho_region)
121+
function set_ortho_region(tn::TTN{V}, ortho_region) where {V}
122+
return TreeTensorNetwork{V}(
123+
tn.graph, tn.vertex_data, tn.ind_to_vertices, Indices{V}(ortho_region)
154124
)
155125
end

test/utils.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# inside its gensym module.
55

66
using DataGraphs: underlying_graph, vertex_data
7-
using Dictionaries: Dictionary
7+
using Dictionaries: Indices
88
using Graphs: AbstractGraph, dst, edges, src, vertices
99
using ITensorNetworks: ITensorNetwork, IndsNetwork
1010
using ITensors.NDTensors: dim
@@ -36,15 +36,15 @@ function random_tensornetwork(
3636
g = NamedGraph(graph)
3737
links = Dict(e => Index(link_space, "Link") for e in edges(g))
3838
links = merge(links, Dict(reverse(e) => links[e] for e in edges(g)))
39-
# Use a `Dictionary` (insertion-ordered) so the constructed
40-
# `ITensorNetwork`'s vertex / edge order tracks `vertices(g)`.
41-
vs = collect(vertices(g))
42-
ts = map(vs) do v
39+
# `Indices`-keyed `map` returns a `Dictionary` (insertion-ordered),
40+
# so the constructed `ITensorNetwork`'s vertex / edge order tracks
41+
# `vertices(g)`.
42+
ts = map(Indices(vertices(g))) do v
4343
link_v = [links[e] for e in incident_edges(g, v)]
4444
inds_v = [siteinds[v]; link_v]
4545
return itensor(randn(rng, eltype, dim.(inds_v)...), inds_v)
4646
end
47-
return ITensorNetwork(Dictionary(vs, ts))
47+
return ITensorNetwork(ts)
4848
end
4949

5050
# `IndsNetwork`: extract site inds (`Index[]` where unassigned).
@@ -108,13 +108,12 @@ function productstate(elt::Type, state::Function, s::IndsNetwork)
108108
return productstate(elt, Dict(v => state(v) for v in vertices(s)), s)
109109
end
110110
function productstate(elt::Type, state, s::IndsNetwork)
111-
vs = collect(vertices(s))
112-
ts = map(vs) do v
111+
ts = map(Indices(vertices(s))) do v
113112
site_v = isassigned(vertex_data(s), v) ? s[v] : Index[]
114113
t = ITensors.state(state[v], only(site_v))
115114
return ITensors.convert_eltype(elt, t)
116115
end
117-
tn = ITensorNetwork(Dictionary(vs, ts))
116+
tn = ITensorNetwork(ts)
118117
for e in edges(s)
119118
_add_edge!(elt, tn, e)
120119
end

0 commit comments

Comments
 (0)