11using . ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype
22using Adapt: Adapt, adapt, adapt_structure
3- using DataGraphs:
4- DataGraphs, edge_data , underlying_graph, underlying_graph_type, vertex_data
3+ using DataGraphs: DataGraphs, edge_data, get_vertex_data, is_vertex_assigned,
4+ set_vertex_data! , underlying_graph, underlying_graph_type, vertex_data
55using Dictionaries: Dictionary
66using Graphs: Graphs, Graph, add_edge!, add_vertex!, bfs_tree, center, dst, edges, edgetype,
77 ne, neighbors, rem_edge!, src, vertices
@@ -13,7 +13,7 @@ using MacroTools: @capture
1313using NDTensors: NDTensors, Algorithm, dim, scalartype
1414using NamedGraphs. GraphsExtensions:
1515 directed_graph, incident_edges, rename_vertices, vertextype, ⊔
16- using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
16+ using NamedGraphs: NamedGraphs, NamedGraph, Vertices, not_implemented, steiner_tree
1717using SplitApplyCombine: flatten
1818
1919abstract type AbstractITensorNetwork{V} <: AbstractDataGraph{V, ITensor, ITensor} end
@@ -23,7 +23,7 @@ data_graph_type(::Type{<:AbstractITensorNetwork}) = not_implemented()
2323data_graph (graph:: AbstractITensorNetwork ) = not_implemented ()
2424
2525# TODO : Define a generic fallback for `AbstractDataGraph`?
26- DataGraphs. edge_data_eltype (:: Type{<:AbstractITensorNetwork} ) = ITensor
26+ DataGraphs. edge_data_type (:: Type{<:AbstractITensorNetwork} ) = ITensor
2727
2828# Graphs.jl overloads
2929function Graphs. weights (graph:: AbstractITensorNetwork )
@@ -49,6 +49,7 @@ Base.eltype(tn::AbstractITensorNetwork) = eltype(vertex_data(tn))
4949
5050# Overload if needed
5151Graphs. is_directed (:: Type{<:AbstractITensorNetwork} ) = false
52+ GraphsExtensions. directed_graph (is:: AbstractITensorNetwork ) = directed_graph (data_graph (is))
5253
5354# Derived interface, may need to be overloaded
5455function DataGraphs. underlying_graph_type (G:: Type{<:AbstractITensorNetwork} )
@@ -59,15 +60,84 @@ function ITensors.datatype(tn::AbstractITensorNetwork)
5960 return mapreduce (v -> datatype (tn[v]), promote_type, vertices (tn))
6061end
6162
62- # AbstractDataGraphs overloads
63- function DataGraphs . vertex_data (graph :: AbstractITensorNetwork , args ... )
64- return vertex_data ( data_graph (graph), args ... )
63+ # TODO : Move to `BaseExtensions` module.
64+ function is_setindex!_expr (expr :: Expr )
65+ return is_assignment_expr (expr) && is_getindex_expr ( first (expr . args) )
6566end
66- function DataGraphs. edge_data (graph:: AbstractITensorNetwork , args... )
67- return edge_data (data_graph (graph), args... )
67+ is_setindex!_expr (x) = false
68+ is_getindex_expr (expr:: Expr ) = (expr. head === :ref )
69+ is_getindex_expr (x) = false
70+ is_assignment_expr (expr:: Expr ) = (expr. head === :(= ))
71+ is_assignment_expr (expr) = false
72+
73+ # TODO : Define this in terms of a function mapping
74+ # preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph
75+ # preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph
76+ # Also allow annotating codeblocks like `@views`.
77+ macro preserve_graph (expr)
78+ if ! is_setindex!_expr (expr)
79+ error (
80+ " preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)"
81+ )
82+ end
83+ @capture (expr, array_[indices__] = value_)
84+ return :(setindex_preserve_graph! ($ (esc (array)), $ (esc (value)), $ (esc .(indices)... )))
6885end
6986
87+ function setindex_preserve_graph! (tn:: AbstractITensorNetwork , value, vertex)
88+ data_graph (tn)[vertex] = value
89+ return tn
90+ end
91+
92+ # AbstractDataGraphs overloads
93+
7094DataGraphs. underlying_graph (tn:: AbstractITensorNetwork ) = underlying_graph (data_graph (tn))
95+
96+ function DataGraphs. is_vertex_assigned (is:: AbstractITensorNetwork , v)
97+ return is_vertex_assigned (data_graph (is), v)
98+ end
99+
100+ function DataGraphs. is_edge_assigned (is:: AbstractITensorNetwork , v)
101+ return is_edge_assigned (data_graph (is), v)
102+ end
103+
104+ function DataGraphs. get_vertex_data (is:: AbstractITensorNetwork , v)
105+ return get_vertex_data (data_graph (is), v)
106+ end
107+
108+ function DataGraphs. set_vertex_data! (tn:: AbstractITensorNetwork , value, v)
109+ # v = to_vertex(tn, index...)
110+ @preserve_graph tn[v] = value
111+ fix_edges! (tn, v)
112+ return tn
113+ end
114+
115+ function DataGraphs. set_vertices_data! (tn:: AbstractITensorNetwork , values, vertices)
116+ # v = to_vertex(tn, index...)
117+ for v in vertices
118+ @preserve_graph tn[v] = values[v]
119+ end
120+ for v in vertices
121+ fix_edges! (tn, v)
122+ end
123+ return tn
124+ end
125+
126+ function fix_edges! (tn:: AbstractITensorNetwork , v)
127+ for edge in incident_edges (tn, v)
128+ rem_edge! (tn, edge)
129+ end
130+ for vertex in vertices (tn)
131+ if v ≠ vertex
132+ edge = v => vertex
133+ if hascommoninds (tn, edge)
134+ add_edge! (tn, edge)
135+ end
136+ end
137+ end
138+ return tn
139+ end
140+
71141function NamedGraphs. vertex_positions (tn:: AbstractITensorNetwork )
72142 return NamedGraphs. vertex_positions (underlying_graph (tn))
73143end
119189# Data modification
120190#
121191
122- function setindex_preserve_graph! (tn:: AbstractITensorNetwork , value, vertex)
123- data_graph (tn)[vertex] = value
124- return tn
125- end
126-
127- # TODO : Move to `BaseExtensions` module.
128- function is_setindex!_expr (expr:: Expr )
129- return is_assignment_expr (expr) && is_getindex_expr (first (expr. args))
130- end
131- is_setindex!_expr (x) = false
132- is_getindex_expr (expr:: Expr ) = (expr. head === :ref )
133- is_getindex_expr (x) = false
134- is_assignment_expr (expr:: Expr ) = (expr. head === :(= ))
135- is_assignment_expr (expr) = false
136-
137- # TODO : Define this in terms of a function mapping
138- # preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph
139- # preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph
140- # Also allow annotating codeblocks like `@views`.
141- macro preserve_graph (expr)
142- if ! is_setindex!_expr (expr)
143- error (
144- " preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)"
145- )
146- end
147- @capture (expr, array_[indices__] = value_)
148- return :(setindex_preserve_graph! ($ (esc (array)), $ (esc (value)), $ (esc .(indices)... )))
149- end
150-
151192function ITensors. hascommoninds (tn:: AbstractITensorNetwork , edge:: Pair )
152193 return hascommoninds (tn, edgetype (tn)(edge))
153194end
@@ -156,23 +197,6 @@ function ITensors.hascommoninds(tn::AbstractITensorNetwork, edge::AbstractEdge)
156197 return hascommoninds (tn[src (edge)], tn[dst (edge)])
157198end
158199
159- function Base. setindex! (tn:: AbstractITensorNetwork , value, v)
160- # v = to_vertex(tn, index...)
161- @preserve_graph tn[v] = value
162- for edge in incident_edges (tn, v)
163- rem_edge! (tn, edge)
164- end
165- for vertex in vertices (tn)
166- if v ≠ vertex
167- edge = v => vertex
168- if hascommoninds (tn, edge)
169- add_edge! (tn, edge)
170- end
171- end
172- end
173- return tn
174- end
175-
176200# Convenience wrapper
177201function eachtensor (tn:: AbstractITensorNetwork , vertices = vertices (tn))
178202 return map (v -> tn[v], vertices)
725749function linkinds_combiners (tn:: AbstractITensorNetwork ; edges = edges (tn))
726750 combiners = DataGraph (
727751 directed_graph (underlying_graph (tn));
728- vertex_data_eltype = ITensor,
729- edge_data_eltype = ITensor
752+ vertex_data_type = ITensor,
753+ edge_data_type = ITensor
730754 )
731755 for e in edges
732756 C = combiner (linkinds (tn, e); tags = edge_tag (e))
739763function combine_linkinds (tn:: AbstractITensorNetwork , combiners)
740764 combined_tn = copy (tn)
741765 for e in edges (tn)
742- if ! isempty (linkinds (tn, e)) && haskey ( edge_data ( combiners) , e)
766+ if ! isempty (linkinds (tn, e)) && isassigned ( combiners, e)
743767 combined_tn[src (e)] = combined_tn[src (e)] * combiners[e]
744768 combined_tn[dst (e)] = combined_tn[dst (e)] * combiners[reverse (e)]
745769 end
845869
846870function linkdims (tn:: AbstractITensorNetwork{V} ) where {V}
847871 ld = DataGraph {V} (
848- copy (underlying_graph (tn)); vertex_data_eltype = Nothing, edge_data_eltype = Int
872+ copy (underlying_graph (tn)); vertex_data_type = Nothing, edge_data_type = Int
849873 )
850874 for e in edges (ld)
851875 ld[e] = linkdim (tn, e)
@@ -995,3 +1019,15 @@ end
9951019Base.:+ (tn1:: AbstractITensorNetwork , tn2:: AbstractITensorNetwork ) = add (tn1, tn2)
9961020
9971021ITensors. hasqns (tn:: AbstractITensorNetwork ) = any (v -> hasqns (tn[v]), vertices (tn))
1022+
1023+ function NamedGraphs. induced_subgraph_from_vertices (
1024+ itn:: AbstractITensorNetwork ,
1025+ subvertices
1026+ )
1027+ subgraph, vlist = induced_subgraph (underlying_graph (itn), subvertices)
1028+ subitn = similar_graph (itn, subgraph)
1029+
1030+ subitn[Vertices (subvertices)] = vertex_data (itn)
1031+
1032+ return subitn, vlist
1033+ end
0 commit comments