Skip to content

Commit a93b2d2

Browse files
committed
test/utils.jl: refactor random_tensornetwork core signature, simplify productstate
random_tensornetwork's core method now takes (rng, eltype, graph::AbstractGraph, siteinds; link_space). The IndsNetwork variant extracts siteinds via a per-vertex dict; the plain-graph variant supplies empty siteinds. Symmetrize the link dictionary with merge(l, Dict(reverse(e) => l[e])) so incident_edges lookup doesn't need a reverse-key fallback. Rename `sites_or_graph` to `sites` in the default-RNG/eltype wrappers. productstate drops the g_full/g_empty pair and iterates s's vertices and edges directly.
1 parent b4f1bd5 commit a93b2d2

1 file changed

Lines changed: 39 additions & 24 deletions

File tree

test/utils.jl

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,43 +15,59 @@ using Random: Random, AbstractRNG
1515

1616
# --- random_tensornetwork ----------------------------------------------------
1717

18-
# At each vertex of `s`'s graph, place an `itensor(randn(rng, eltype, ...), inds_v)`
19-
# whose inds are the site inds at that vertex (from `s[v]`, or empty if unassigned)
20-
# concatenated with one fresh `Index(link_space, "Link")` per incident edge, shared
21-
# with the other endpoint.
18+
# Core: at each vertex of `graph`, place an `itensor(randn(rng, eltype, ...), inds_v)`
19+
# whose inds are `siteinds[v]` concatenated with one fresh `Index(link_space, "Link")`
20+
# per incident edge, shared with the other endpoint. `siteinds` is anything indexable
21+
# by vertex (`keys(siteinds)` matches `vertices(graph)`); use `Index[]` per vertex for
22+
# no site inds.
2223
function random_tensornetwork(
23-
rng::AbstractRNG, eltype::Type, s::IndsNetwork; link_space = 1
24+
rng::AbstractRNG, eltype::Type, graph::AbstractGraph, siteinds; link_space = 1
2425
)
25-
g = NamedGraph(underlying_graph(s))
26+
g = NamedGraph(graph)
2627
links = Dict(e => Index(link_space, "Link") for e in edges(g))
28+
links = merge(links, Dict(reverse(e) => links[e] for e in edges(g)))
2729
tensors = Dict(
2830
map(collect(vertices(g))) do v
29-
site_v = isassigned(vertex_data(s), v) ? s[v] : Index[]
30-
link_v = [
31-
haskey(links, e) ? links[e] : links[reverse(e)]
32-
for e in incident_edges(g, v)
33-
]
34-
inds_v = [site_v; link_v]
31+
link_v = [links[e] for e in incident_edges(g, v)]
32+
inds_v = [siteinds[v]; link_v]
3533
return v => itensor(randn(rng, eltype, dim.(inds_v)...), inds_v)
3634
end
3735
)
3836
return ITensorNetwork(tensors, g)
3937
end
38+
39+
# `IndsNetwork`: extract site inds (`Index[]` where unassigned).
40+
function random_tensornetwork(
41+
rng::AbstractRNG, eltype::Type, s::IndsNetwork; kwargs...
42+
)
43+
siteinds = Dict(
44+
v => isassigned(vertex_data(s), v) ? s[v] : Index[] for v in vertices(s)
45+
)
46+
return random_tensornetwork(rng, eltype, underlying_graph(s), siteinds; kwargs...)
47+
end
48+
49+
# Plain graph: no site inds.
4050
function random_tensornetwork(
4151
rng::AbstractRNG, eltype::Type, g::AbstractGraph; kwargs...
4252
)
43-
return random_tensornetwork(rng, eltype, IndsNetwork(g); kwargs...)
53+
return random_tensornetwork(
54+
rng,
55+
eltype,
56+
g,
57+
Dict(v => Index[] for v in vertices(g));
58+
kwargs...
59+
)
4460
end
4561

4662
# RNG / eltype / both defaults
47-
function random_tensornetwork(rng::AbstractRNG, sites_or_graph; kwargs...)
48-
return random_tensornetwork(rng, Float64, sites_or_graph; kwargs...)
63+
function random_tensornetwork(rng::AbstractRNG, sites; kwargs...)
64+
return random_tensornetwork(rng, Float64, sites; kwargs...)
4965
end
50-
function random_tensornetwork(eltype::Type, sites_or_graph; kwargs...)
51-
return random_tensornetwork(Random.default_rng(), eltype, sites_or_graph; kwargs...)
66+
function random_tensornetwork(eltype::Type, sites; kwargs...)
67+
return random_tensornetwork(Random.default_rng(), eltype, sites; kwargs...)
5268
end
53-
function random_tensornetwork(sites_or_graph; kwargs...)
54-
return random_tensornetwork(Random.default_rng(), Float64, sites_or_graph; kwargs...)
69+
function random_tensornetwork(sites; kwargs...)
70+
return random_tensornetwork(Random.default_rng(), Float64, sites; kwargs...)
5571
end
5672

5773
# --- productstate -------------------------------------------------------------
@@ -80,17 +96,16 @@ function productstate(elt::Type, state::Function, s::IndsNetwork)
8096
return productstate(elt, Dict(v => state(v) for v in vertices(s)), s)
8197
end
8298
function productstate(elt::Type, state, s::IndsNetwork)
83-
g_full = NamedGraph(underlying_graph(s))
84-
g_empty = NamedGraph(collect(vertices(g_full)))
99+
g = NamedGraph(collect(vertices(s)))
85100
tensors = Dict(
86-
map(collect(vertices(g_empty))) do v
101+
map(collect(vertices(s))) do v
87102
site_v = isassigned(vertex_data(s), v) ? s[v] : Index[]
88103
t = ITensors.state(state[v], only(site_v))
89104
return v => ITensors.convert_eltype(elt, t)
90105
end
91106
)
92-
tn = ITensorNetwork(tensors, g_empty)
93-
for e in edges(g_full)
107+
tn = ITensorNetwork(tensors, g)
108+
for e in edges(s)
94109
_add_edge!(elt, tn, e)
95110
end
96111
return tn

0 commit comments

Comments
 (0)