Skip to content

Commit c6d7e64

Browse files
Upgrade ITensorNetworks to use DataGraphs v0.4.0 and NamedGraphs v0.11.0. (#317)
This package now used the proper `AbstractDataGraph` interface from DataGraphs, and the `similar_graph` interface from `NamedGraphs`. --------- Co-authored-by: Matt Fishman <mtfishman@users.noreply.github.com>
1 parent 9ed5949 commit c6d7e64

19 files changed

Lines changed: 259 additions & 187 deletions

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
3-
version = "0.19.4"
3+
version = "0.19.5"
44
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>, Joseph Tindall <jtindall@flatironinstitute.org> and contributors"]
55

66
[workspace]
@@ -54,7 +54,7 @@ Adapt = "4"
5454
Combinatorics = "1"
5555
Compat = "3, 4"
5656
ConstructionBase = "1.6"
57-
DataGraphs = "0.2.13"
57+
DataGraphs = "0.4"
5858
Dictionaries = "0.4"
5959
Distributions = "0.25.86"
6060
DocStringExtensions = "0.9"
@@ -65,7 +65,7 @@ IterTools = "1.4"
6565
KrylovKit = "0.6, 0.7, 0.8, 0.9, 0.10"
6666
MacroTools = "0.5"
6767
NDTensors = "0.3, 0.4"
68-
NamedGraphs = "0.8.2"
68+
NamedGraphs = "0.11"
6969
OMEinsumContractionOrders = "0.8.3, 0.9, 1"
7070
Observers = "0.2.4"
7171
SerializedElementArrays = "0.1"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@ ITensorFormatter = "0.2.27"
1919
ITensorNetworks = "0.19"
2020
ITensors = "0.9"
2121
Literate = "2.20.1"
22-
NamedGraphs = "0.8.2"
22+
NamedGraphs = "0.11"
2323
OMEinsumContractionOrders = "1.2.2"
2424
TensorOperations = "5.5"

docs/make.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ DocMeta.setdocmeta!(
1313
quote
1414
using Graphs: dst, edges, src, vertices
1515
using ITensorNetworks
16-
using ITensorNetworks: TreeTensorNetwork, expect, loginner, mps, orthogonalize,
17-
random_mps, random_ttn, siteinds, truncate, ttn
16+
using ITensorNetworks:
17+
TreeTensorNetwork, expect, loginner, mps, orthogonalize, siteinds, truncate, ttn
1818
using ITensors: inner
1919
using LinearAlgebra: norm, normalize
2020
using OMEinsumContractionOrders

src/abstractindsnetwork.jl

Lines changed: 30 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
using .ITensorsExtensions: ITensorsExtensions, promote_indtype
2-
using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, vertex_data
2+
using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph, IsUnderlyingGraph, edge_data,
3+
get_edge_data, get_vertex_data, is_edge_assigned, is_vertex_assigned, map_data,
4+
set_edge_data!, set_vertex_data!, underlying_graph_type, vertex_data
35
using Graphs: Graphs, AbstractEdge
46
using ITensors: ITensors, IndexSet, unioninds, uniqueinds
5-
using NamedGraphs.GraphsExtensions: incident_edges, rename_vertices
7+
using NamedGraphs.GraphsExtensions:
8+
GraphsExtensions, directed_graph, incident_edges, rename_vertices
69
using NamedGraphs: NamedGraphs
710

811
abstract type AbstractIndsNetwork{V, I} <: AbstractDataGraph{V, Vector{I}, Vector{I}} end
@@ -12,59 +15,36 @@ data_graph(graph::AbstractIndsNetwork) = not_implemented()
1215

1316
# Overload if needed
1417
Graphs.is_directed(::Type{<:AbstractIndsNetwork}) = false
18+
GraphsExtensions.directed_graph(is::AbstractIndsNetwork) = directed_graph(data_graph(is))
1519

1620
# AbstractDataGraphs overloads
17-
function DataGraphs.vertex_data(graph::AbstractIndsNetwork, args...)
18-
return vertex_data(data_graph(graph), args...)
21+
DataGraphs.underlying_graph(is::AbstractIndsNetwork) = underlying_graph(data_graph(is))
22+
23+
# TODO: Define a generic fallback for `AbstractDataGraph`?
24+
DataGraphs.edge_data_type(::Type{<:AbstractIndsNetwork{V, I}}) where {V, I} = Vector{I}
25+
DataGraphs.vertex_data_type(::Type{<:AbstractIndsNetwork{V, I}}) where {V, I} = Vector{I}
26+
27+
function DataGraphs.is_vertex_assigned(is::AbstractIndsNetwork, v)
28+
return is_vertex_assigned(data_graph(is), v)
29+
end
30+
31+
function DataGraphs.set_vertex_data!(is::AbstractIndsNetwork, v, data)
32+
return set_vertex_data!(data_graph(is), v, data)
1933
end
20-
function DataGraphs.edge_data(graph::AbstractIndsNetwork, args...)
21-
return edge_data(data_graph(graph), args...)
34+
function DataGraphs.get_vertex_data(is::AbstractIndsNetwork, v)
35+
return get_vertex_data(data_graph(is), v)
2236
end
2337

24-
# TODO: Define a generic fallback for `AbstractDataGraph`?
25-
DataGraphs.edge_data_eltype(::Type{<:AbstractIndsNetwork{V, I}}) where {V, I} = Vector{I}
26-
27-
## TODO: Bring these back.
28-
## function indsnetwork_getindex(is::AbstractIndsNetwork, index)
29-
## return get(data_graph(is), index, indtype(is)[])
30-
## end
31-
##
32-
## function Base.getindex(is::AbstractIndsNetwork, index)
33-
## return indsnetwork_getindex(is, index)
34-
## end
35-
##
36-
## function Base.getindex(is::AbstractIndsNetwork, index::Pair)
37-
## return indsnetwork_getindex(is, index)
38-
## end
39-
##
40-
## function Base.getindex(is::AbstractIndsNetwork, index::AbstractEdge)
41-
## return indsnetwork_getindex(is, index)
42-
## end
43-
##
44-
## function indsnetwork_setindex!(is::AbstractIndsNetwork, value, index)
45-
## data_graph(is)[index] = value
46-
## return is
47-
## end
48-
##
49-
## function Base.setindex!(is::AbstractIndsNetwork, value, index)
50-
## indsnetwork_setindex!(is, value, index)
51-
## return is
52-
## end
53-
##
54-
## function Base.setindex!(is::AbstractIndsNetwork, value, index::Pair)
55-
## indsnetwork_setindex!(is, value, index)
56-
## return is
57-
## end
58-
##
59-
## function Base.setindex!(is::AbstractIndsNetwork, value, index::AbstractEdge)
60-
## indsnetwork_setindex!(is, value, index)
61-
## return is
62-
## end
63-
##
64-
## function Base.setindex!(is::AbstractIndsNetwork, value::Index, index)
65-
## indsnetwork_setindex!(is, value, index)
66-
## return is
67-
## end
38+
function DataGraphs.is_edge_assigned(is::AbstractIndsNetwork, v)
39+
return is_edge_assigned(data_graph(is), v)
40+
end
41+
42+
function DataGraphs.set_edge_data!(is::AbstractIndsNetwork, v, data)
43+
return set_edge_data!(data_graph(is), v, data)
44+
end
45+
function DataGraphs.get_edge_data(is::AbstractIndsNetwork, v)
46+
return get_edge_data(data_graph(is), v)
47+
end
6848

6949
#
7050
# Index access

src/abstractitensornetwork.jl

Lines changed: 95 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype
22
using Adapt: Adapt, adapt, adapt_structure
3-
using DataGraphs:
4-
DataGraphs, edge_data, underlying_graph, underlying_graph_type, vertex_data
3+
using DataGraphs: DataGraphs, edge_data, get_vertex_data, is_vertex_assigned,
4+
set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data
55
using Dictionaries: Dictionary
66
using Graphs: Graphs, Graph, add_edge!, add_vertex!, bfs_tree, center, dst, edges, edgetype,
77
ne, neighbors, rem_edge!, src, vertices
@@ -13,7 +13,7 @@ using MacroTools: @capture
1313
using NDTensors: NDTensors, Algorithm, dim, scalartype
1414
using NamedGraphs.GraphsExtensions:
1515
directed_graph, incident_edges, rename_vertices, vertextype,
16-
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
16+
using NamedGraphs: NamedGraphs, NamedGraph, Vertices, not_implemented, steiner_tree
1717
using SplitApplyCombine: flatten
1818

1919
abstract type AbstractITensorNetwork{V} <: AbstractDataGraph{V, ITensor, ITensor} end
@@ -23,7 +23,7 @@ data_graph_type(::Type{<:AbstractITensorNetwork}) = not_implemented()
2323
data_graph(graph::AbstractITensorNetwork) = not_implemented()
2424

2525
# TODO: Define a generic fallback for `AbstractDataGraph`?
26-
DataGraphs.edge_data_eltype(::Type{<:AbstractITensorNetwork}) = ITensor
26+
DataGraphs.edge_data_type(::Type{<:AbstractITensorNetwork}) = ITensor
2727

2828
# Graphs.jl overloads
2929
function Graphs.weights(graph::AbstractITensorNetwork)
@@ -49,6 +49,7 @@ Base.eltype(tn::AbstractITensorNetwork) = eltype(vertex_data(tn))
4949

5050
# Overload if needed
5151
Graphs.is_directed(::Type{<:AbstractITensorNetwork}) = false
52+
GraphsExtensions.directed_graph(is::AbstractITensorNetwork) = directed_graph(data_graph(is))
5253

5354
# Derived interface, may need to be overloaded
5455
function DataGraphs.underlying_graph_type(G::Type{<:AbstractITensorNetwork})
@@ -59,15 +60,84 @@ function ITensors.datatype(tn::AbstractITensorNetwork)
5960
return mapreduce(v -> datatype(tn[v]), promote_type, vertices(tn))
6061
end
6162

62-
# AbstractDataGraphs overloads
63-
function DataGraphs.vertex_data(graph::AbstractITensorNetwork, args...)
64-
return vertex_data(data_graph(graph), args...)
63+
# TODO: Move to `BaseExtensions` module.
64+
function is_setindex!_expr(expr::Expr)
65+
return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
6566
end
66-
function DataGraphs.edge_data(graph::AbstractITensorNetwork, args...)
67-
return edge_data(data_graph(graph), args...)
67+
is_setindex!_expr(x) = false
68+
is_getindex_expr(expr::Expr) = (expr.head === :ref)
69+
is_getindex_expr(x) = false
70+
is_assignment_expr(expr::Expr) = (expr.head === :(=))
71+
is_assignment_expr(expr) = false
72+
73+
# TODO: Define this in terms of a function mapping
74+
# preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph
75+
# preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph
76+
# Also allow annotating codeblocks like `@views`.
77+
macro preserve_graph(expr)
78+
if !is_setindex!_expr(expr)
79+
error(
80+
"preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)"
81+
)
82+
end
83+
@capture(expr, array_[indices__] = value_)
84+
return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...)))
6885
end
6986

87+
function setindex_preserve_graph!(tn::AbstractITensorNetwork, value, vertex)
88+
data_graph(tn)[vertex] = value
89+
return tn
90+
end
91+
92+
# AbstractDataGraphs overloads
93+
7094
DataGraphs.underlying_graph(tn::AbstractITensorNetwork) = underlying_graph(data_graph(tn))
95+
96+
function DataGraphs.is_vertex_assigned(is::AbstractITensorNetwork, v)
97+
return is_vertex_assigned(data_graph(is), v)
98+
end
99+
100+
function DataGraphs.is_edge_assigned(is::AbstractITensorNetwork, v)
101+
return is_edge_assigned(data_graph(is), v)
102+
end
103+
104+
function DataGraphs.get_vertex_data(is::AbstractITensorNetwork, v)
105+
return get_vertex_data(data_graph(is), v)
106+
end
107+
108+
function DataGraphs.set_vertex_data!(tn::AbstractITensorNetwork, value, v)
109+
# v = to_vertex(tn, index...)
110+
@preserve_graph tn[v] = value
111+
fix_edges!(tn, v)
112+
return tn
113+
end
114+
115+
function DataGraphs.set_vertices_data!(tn::AbstractITensorNetwork, values, vertices)
116+
# v = to_vertex(tn, index...)
117+
for v in vertices
118+
@preserve_graph tn[v] = values[v]
119+
end
120+
for v in vertices
121+
fix_edges!(tn, v)
122+
end
123+
return tn
124+
end
125+
126+
function fix_edges!(tn::AbstractITensorNetwork, v)
127+
for edge in incident_edges(tn, v)
128+
rem_edge!(tn, edge)
129+
end
130+
for vertex in vertices(tn)
131+
if v vertex
132+
edge = v => vertex
133+
if hascommoninds(tn, edge)
134+
add_edge!(tn, edge)
135+
end
136+
end
137+
end
138+
return tn
139+
end
140+
71141
function NamedGraphs.vertex_positions(tn::AbstractITensorNetwork)
72142
return NamedGraphs.vertex_positions(underlying_graph(tn))
73143
end
@@ -119,35 +189,6 @@ end
119189
# Data modification
120190
#
121191

122-
function setindex_preserve_graph!(tn::AbstractITensorNetwork, value, vertex)
123-
data_graph(tn)[vertex] = value
124-
return tn
125-
end
126-
127-
# TODO: Move to `BaseExtensions` module.
128-
function is_setindex!_expr(expr::Expr)
129-
return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
130-
end
131-
is_setindex!_expr(x) = false
132-
is_getindex_expr(expr::Expr) = (expr.head === :ref)
133-
is_getindex_expr(x) = false
134-
is_assignment_expr(expr::Expr) = (expr.head === :(=))
135-
is_assignment_expr(expr) = false
136-
137-
# TODO: Define this in terms of a function mapping
138-
# preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph
139-
# preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph
140-
# Also allow annotating codeblocks like `@views`.
141-
macro preserve_graph(expr)
142-
if !is_setindex!_expr(expr)
143-
error(
144-
"preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)"
145-
)
146-
end
147-
@capture(expr, array_[indices__] = value_)
148-
return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...)))
149-
end
150-
151192
function ITensors.hascommoninds(tn::AbstractITensorNetwork, edge::Pair)
152193
return hascommoninds(tn, edgetype(tn)(edge))
153194
end
@@ -156,23 +197,6 @@ function ITensors.hascommoninds(tn::AbstractITensorNetwork, edge::AbstractEdge)
156197
return hascommoninds(tn[src(edge)], tn[dst(edge)])
157198
end
158199

159-
function Base.setindex!(tn::AbstractITensorNetwork, value, v)
160-
# v = to_vertex(tn, index...)
161-
@preserve_graph tn[v] = value
162-
for edge in incident_edges(tn, v)
163-
rem_edge!(tn, edge)
164-
end
165-
for vertex in vertices(tn)
166-
if v vertex
167-
edge = v => vertex
168-
if hascommoninds(tn, edge)
169-
add_edge!(tn, edge)
170-
end
171-
end
172-
end
173-
return tn
174-
end
175-
176200
# Convenience wrapper
177201
function eachtensor(tn::AbstractITensorNetwork, vertices = vertices(tn))
178202
return map(v -> tn[v], vertices)
@@ -725,8 +749,8 @@ end
725749
function linkinds_combiners(tn::AbstractITensorNetwork; edges = edges(tn))
726750
combiners = DataGraph(
727751
directed_graph(underlying_graph(tn));
728-
vertex_data_eltype = ITensor,
729-
edge_data_eltype = ITensor
752+
vertex_data_type = ITensor,
753+
edge_data_type = ITensor
730754
)
731755
for e in edges
732756
C = combiner(linkinds(tn, e); tags = edge_tag(e))
@@ -739,7 +763,7 @@ end
739763
function combine_linkinds(tn::AbstractITensorNetwork, combiners)
740764
combined_tn = copy(tn)
741765
for e in edges(tn)
742-
if !isempty(linkinds(tn, e)) && haskey(edge_data(combiners), e)
766+
if !isempty(linkinds(tn, e)) && isassigned(combiners, e)
743767
combined_tn[src(e)] = combined_tn[src(e)] * combiners[e]
744768
combined_tn[dst(e)] = combined_tn[dst(e)] * combiners[reverse(e)]
745769
end
@@ -845,7 +869,7 @@ end
845869

846870
function linkdims(tn::AbstractITensorNetwork{V}) where {V}
847871
ld = DataGraph{V}(
848-
copy(underlying_graph(tn)); vertex_data_eltype = Nothing, edge_data_eltype = Int
872+
copy(underlying_graph(tn)); vertex_data_type = Nothing, edge_data_type = Int
849873
)
850874
for e in edges(ld)
851875
ld[e] = linkdim(tn, e)
@@ -995,3 +1019,15 @@ end
9951019
Base.:+(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork) = add(tn1, tn2)
9961020

9971021
ITensors.hasqns(tn::AbstractITensorNetwork) = any(v -> hasqns(tn[v]), vertices(tn))
1022+
1023+
function NamedGraphs.induced_subgraph_from_vertices(
1024+
itn::AbstractITensorNetwork,
1025+
subvertices
1026+
)
1027+
subgraph, vlist = induced_subgraph(underlying_graph(itn), subvertices)
1028+
subitn = similar_graph(itn, subgraph)
1029+
1030+
subitn[Vertices(subvertices)] = vertex_data(itn)
1031+
1032+
return subitn, vlist
1033+
end

0 commit comments

Comments
 (0)