11using Combinatorics: combinations
22using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph
3- using Dictionaries: AbstractDictionary, Indices, dictionary
3+ using Dictionaries: AbstractDictionary, Indices, dictionary, set!, unset!
44using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge!
55using NamedDimsArrays: AbstractNamedDimsArray, dimnames
6- using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype
6+ using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype, Vertices, parent_graph_indices
77using NamedGraphs. GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype
88using NamedGraphs. PartitionedGraphs:
99 AbstractPartitionedGraph,
@@ -12,9 +12,13 @@ using NamedGraphs.PartitionedGraphs:
1212 partitioned_vertices,
1313 partitionedgraph,
1414 quotient_graph,
15- quotient_graph_type
15+ quotient_graph_type,
16+ QuotientVertex,
17+ QuotientVertices,
18+ QuotientVertexVertices,
19+ quotientvertices
1620using . LazyNamedDimsArrays: lazy, Mul
17- using DataGraphs: vertex_data_eltype, vertex_data, edge_data
21+ using DataGraphs: vertex_data_eltype, vertex_data, edge_data, get_vertices_data
1822using DataGraphs. DataGraphsPartitionedGraphsExt
1923
2024function _TensorNetwork end
@@ -31,18 +35,26 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar
3135 end
3236end
3337# This assumes the tensor connectivity matches the graph structure.
34- function _TensorNetwork (graph:: AbstractGraph , tensors)
38+ function TensorNetwork (graph:: AbstractGraph , tensors:: AbstractDictionary )
3539 return _TensorNetwork (graph, Dictionary (keys (tensors), values (tensors)))
3640end
3741
3842function TensorNetwork {V, VD, UG, Tensors} (graph:: UG ) where {V, VD, UG <: AbstractGraph{V} , Tensors}
3943 return _TensorNetwork (graph, Tensors ())
4044end
4145
42- DataGraphs. underlying_graph (tn:: TensorNetwork ) = getfield (tn, :underlying_graph )
43- DataGraphs. vertex_data (tn:: TensorNetwork ) = getfield (tn, :tensors )
44- DataGraphs. edge_data (tn:: TensorNetwork ) = Dictionary {edgetype(tn), Nothing} ()
45- DataGraphs. vertex_data_eltype (T:: Type{<:TensorNetwork} ) = eltype (fieldtype (T, :tensors ))
46+ # DataGraphs interface
47+
48+ DataGraphs. underlying_graph (tn:: TensorNetwork ) = tn. underlying_graph
49+
50+ DataGraphs. has_vertex_data (tn:: TensorNetwork , v) = haskey (tn. tensors, v)
51+ DataGraphs. has_edge_data (tn:: TensorNetwork , e) = false
52+
53+ DataGraphs. get_vertex_data (tn:: TensorNetwork , v) = tn. tensors[v]
54+
55+ DataGraphs. set_vertex_data! (tn:: TensorNetwork , val, v) = set! (tn. tensors, v, val)
56+ DataGraphs. unset_vertex_data! (tn:: TensorNetwork , val, v) = unset! (tn. tensors, v, val)
57+
4658function DataGraphs. underlying_graph_type (type:: Type{<:TensorNetwork} )
4759 return fieldtype (type, :underlying_graph )
4860end
@@ -123,27 +135,34 @@ function Graphs.rem_edge!(tn::TensorNetwork, e)
123135 return true
124136end
125137
126- function GraphsExtensions. graph_from_vertices (type:: Type{<:TensorNetwork} , vertices )
138+ function GraphsExtensions. similar (type:: Type{<:TensorNetwork} )
127139 DT = fieldtype (type, :tensors )
128140 empty_dict = DT ()
129- return TensorNetwork (similar_graph (underlying_graph_type (type), vertices ), empty_dict)
141+ return TensorNetwork (similar_graph (underlying_graph_type (type)), empty_dict)
130142end
131143
132144# # PartitionedGraphs
133145function PartitionedGraphs. quotient_graph (tn:: TensorNetwork )
134146 ug = quotient_graph (underlying_graph (tn))
135- return TensorNetwork (ug, vertex_data (QuotientView (tn)))
147+
148+ inds = Indices (parent_graph_indices (QuotientVertices (tn)))
149+ data = map (v -> tn[QuotientVertex (v)], inds)
150+
151+ return TensorNetwork (ug, data)
136152end
153+ # TODO : This method should not be required with a better interface with a better
154+ # DataGraphsPartitionedGraphsExt interface.
137155function PartitionedGraphs. quotient_graph_type (type:: Type{<:TensorNetwork} )
138156 UG = quotient_graph_type (underlying_graph_type (type))
139157 VD = Vector{vertex_data_eltype (type)}
140158 V = vertextype (UG)
141159 return TensorNetwork{V, VD, UG, Dictionary{V, VD}}
142160end
143161
162+ # Partition the underlying graph of the tensor network; does not affect the data.
144163function PartitionedGraphs. partitionedgraph (tn:: TensorNetwork , parts)
145164 pg = partitionedgraph (underlying_graph (tn), parts)
146- return TensorNetwork (pg, vertex_data (tn))
165+ return TensorNetwork (pg, copy ( vertex_data (tn) ))
147166end
148167
149168PartitionedGraphs. departition (tn:: TensorNetwork ) = tn
@@ -153,8 +172,9 @@ function PartitionedGraphs.departition(
153172 return TensorNetwork (departition (underlying_graph (tn)), vertex_data (tn))
154173end
155174
156- function DataGraphsPartitionedGraphsExt. to_quotient_vertex_data (:: TensorNetwork , data)
157- return mapreduce (lazy, * , collect (last (data)))
175+ function DataGraphs. get_vertices_data (tn:: TensorNetwork , vertex:: QuotientVertexVertices )
176+ data = collect (map (v -> tn[v], NamedGraphs. parent_graph_indices (vertex)))
177+ return mapreduce (lazy, * , data)
158178end
159179
160180function PartitionedGraphs. quotientview (tn:: TensorNetwork )
0 commit comments