Skip to content

Commit 7c3eb04

Browse files
mtfishmanclaude
andcommitted
Cut down ITN and TTN constructor surface
- Drop ITensorNetwork{V}() empty-arg ctor; inline the three-field call at the single seed site inside ITensorNetwork{V}(tensors). External callers that want an empty network can pass an empty tensor collection (Dict{V, ITensor}() etc.). - Drop the TreeTensorNetwork(::ITensorNetwork; ortho_region=...) and TreeTensorNetwork{V}(::ITensorNetwork) overloads. All construction now routes through TreeTensorNetwork(tensors; ortho_region=nothing), which builds an ITensorNetwork first and then performs the is_tree check. - Define Base.keytype on AbstractITensorNetwork so an AbstractITN can be passed as a tensor collection to ITensorNetwork(tensors). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent dc22cf8 commit 7c3eb04

3 files changed

Lines changed: 20 additions & 30 deletions

File tree

src/abstractitensornetwork.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ Base.copy(tn::AbstractITensorNetwork) = not_implemented()
4545
# whether `vertex_data` is a `Dict`, `Dictionary`, or anything else with
4646
# different default-iteration semantics.
4747
Base.keys(tn::AbstractITensorNetwork) = vertices(tn)
48+
Base.keytype(::Type{<:AbstractITensorNetwork{V}}) where {V} = V
49+
Base.keytype(tn::AbstractITensorNetwork) = keytype(typeof(tn))
4850
Base.values(tn::AbstractITensorNetwork) = (tn[v] for v in vertices(tn))
4951
Base.iterate(tn::AbstractITensorNetwork, args...) = iterate(values(tn), args...)
5052

src/itensornetwork.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,22 +69,15 @@ end
6969
# Constructors
7070
#
7171

72-
# Empty network with no vertices, used as the starting point for the
73-
# tensor-collection constructor below.
74-
function ITensorNetwork{V}() where {V}
75-
return ITensorNetwork{V}(
76-
NamedGraph{V}(), Dictionary{V, ITensor}(), Dict{Index, Set{V}}()
77-
)
78-
end
79-
8072
# Construct by feeding `tensors` through `set_vertex_data!` one vertex
8173
# at a time — this centralizes the reverse-map registration, edge
8274
# inference, and hypergraph check in a single place (the `setindex!`
8375
# code path). Walking `keys(tensors)` in order makes the resulting
8476
# `neighbors(g, v)` / `edges(g)` iteration order deterministic in the
85-
# input order.
77+
# input order. An empty `tensors` (`Dict{V, ITensor}()`, etc.) yields
78+
# an empty network — there is no separate empty-arg constructor.
8679
function ITensorNetwork{V}(tensors) where {V}
87-
tn = ITensorNetwork{V}()
80+
tn = ITensorNetwork{V}(NamedGraph{V}(), Dictionary{V, ITensor}(), Dict{Index, Set{V}}())
8881
for v in keys(tensors)
8982
set_vertex_data!(tn, tensors[v], v)
9083
end

src/treetensornetworks/treetensornetwork.jl

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,17 @@ struct TreeTensorNetwork{V} <: AbstractTreeTensorNetwork{V}
2727
end
2828

2929
"""
30-
TreeTensorNetwork(tn::ITensorNetwork; ortho_region=vertices(tn)) -> TreeTensorNetwork
30+
TreeTensorNetwork(tensors; ortho_region=nothing) -> TreeTensorNetwork
3131
32-
Construct a `TreeTensorNetwork` from an `ITensorNetwork` with tree graph structure.
32+
Construct a `TreeTensorNetwork` from any collection of tensors accepted by
33+
`ITensorNetwork` (e.g. a `Dict`, `Dictionary`, a `Vector{ITensor}`, or another
34+
`AbstractITensorNetwork`). Edges are inferred from shared `Index`es; the
35+
underlying graph must be a tree.
3336
34-
The `ortho_region` keyword specifies which vertices currently form the orthogonality center.
35-
By default all vertices are included, meaning no particular gauge is assumed. To enforce an
36-
actual orthogonal gauge, call [`orthogonalize`](@ref) afterward.
37-
38-
Throws an error if the underlying graph of `tn` is not a tree.
37+
`ortho_region` specifies which vertices currently form the orthogonality
38+
center. The default `nothing` includes all vertices, meaning no particular
39+
gauge is assumed. To enforce an actual orthogonal gauge, call
40+
[`orthogonalize`](@ref) afterward.
3941
4042
# Example
4143
@@ -54,23 +56,16 @@ julia> ttn = TreeTensorNetwork(itn; ortho_region = [first(vertices(itn))]);
5456
5557
See also: [`ITensorNetwork`](@ref), [`orthogonalize`](@ref).
5658
"""
57-
function TreeTensorNetwork(tn::ITensorNetwork{V}; ortho_region = vertices(tn)) where {V}
58-
@assert is_tree(tn)
59+
function TreeTensorNetwork(tensors; ortho_region = nothing)
60+
itn = ITensorNetwork(tensors)
61+
@assert is_tree(itn)
62+
V = vertextype(itn)
63+
region = isnothing(ortho_region) ? vertices(itn) : ortho_region
5964
return TreeTensorNetwork{V}(
60-
tn.graph, tn.vertex_data, tn.ind_to_vertices, Indices{V}(ortho_region)
65+
itn.graph, itn.vertex_data, itn.ind_to_vertices, Indices{V}(region)
6166
)
6267
end
6368

64-
function TreeTensorNetwork{V}(tn::ITensorNetwork) where {V}
65-
return TreeTensorNetwork(ITensorNetwork{V}(tn))
66-
end
67-
68-
# Build a `TreeTensorNetwork` directly from a tensor collection (anything
69-
# accepted by `ITensorNetwork`), saving the caller a wrapping step.
70-
function TreeTensorNetwork(tensors; kwargs...)
71-
return TreeTensorNetwork(ITensorNetwork(tensors); kwargs...)
72-
end
73-
7469
const TTN = TreeTensorNetwork
7570

7671
# Field access

0 commit comments

Comments
 (0)