Skip to content

Commit dd6f645

Browse files
committed
TensorNetwork type now uses new DataGraphs interface
1 parent c43884e commit dd6f645

1 file changed

Lines changed: 35 additions & 15 deletions

File tree

src/tensornetwork.jl

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
using Combinatorics: combinations
22
using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph
3-
using Dictionaries: AbstractDictionary, Indices, dictionary
3+
using Dictionaries: AbstractDictionary, Indices, dictionary, set!, unset!
44
using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge!
55
using NamedDimsArrays: AbstractNamedDimsArray, dimnames
6-
using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype
6+
using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype, Vertices, parent_graph_indices
77
using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype
88
using NamedGraphs.PartitionedGraphs:
99
AbstractPartitionedGraph,
@@ -12,9 +12,13 @@ using NamedGraphs.PartitionedGraphs:
1212
partitioned_vertices,
1313
partitionedgraph,
1414
quotient_graph,
15-
quotient_graph_type
15+
quotient_graph_type,
16+
QuotientVertex,
17+
QuotientVertices,
18+
QuotientVertexVertices,
19+
quotientvertices
1620
using .LazyNamedDimsArrays: lazy, Mul
17-
using DataGraphs: vertex_data_eltype, vertex_data, edge_data
21+
using DataGraphs: vertex_data_eltype, vertex_data, edge_data, get_vertices_data
1822
using DataGraphs.DataGraphsPartitionedGraphsExt
1923

2024
function _TensorNetwork end
@@ -31,18 +35,26 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar
3135
end
3236
end
3337
# This assumes the tensor connectivity matches the graph structure.
34-
function _TensorNetwork(graph::AbstractGraph, tensors)
38+
function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary)
3539
return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors)))
3640
end
3741

3842
function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors}
3943
return _TensorNetwork(graph, Tensors())
4044
end
4145

42-
DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph)
43-
DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors)
44-
DataGraphs.edge_data(tn::TensorNetwork) = Dictionary{edgetype(tn), Nothing}()
45-
DataGraphs.vertex_data_eltype(T::Type{<:TensorNetwork}) = eltype(fieldtype(T, :tensors))
46+
# DataGraphs interface
47+
48+
DataGraphs.underlying_graph(tn::TensorNetwork) = tn.underlying_graph
49+
50+
DataGraphs.has_vertex_data(tn::TensorNetwork, v) = haskey(tn.tensors, v)
51+
DataGraphs.has_edge_data(tn::TensorNetwork, e) = false
52+
53+
DataGraphs.get_vertex_data(tn::TensorNetwork, v) = tn.tensors[v]
54+
55+
DataGraphs.set_vertex_data!(tn::TensorNetwork, val, v) = set!(tn.tensors, v, val)
56+
DataGraphs.unset_vertex_data!(tn::TensorNetwork, val, v) = unset!(tn.tensors, v, val)
57+
4658
function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork})
4759
return fieldtype(type, :underlying_graph)
4860
end
@@ -123,27 +135,34 @@ function Graphs.rem_edge!(tn::TensorNetwork, e)
123135
return true
124136
end
125137

126-
function GraphsExtensions.graph_from_vertices(type::Type{<:TensorNetwork}, vertices)
138+
function GraphsExtensions.similar(type::Type{<:TensorNetwork})
127139
DT = fieldtype(type, :tensors)
128140
empty_dict = DT()
129-
return TensorNetwork(similar_graph(underlying_graph_type(type), vertices), empty_dict)
141+
return TensorNetwork(similar_graph(underlying_graph_type(type)), empty_dict)
130142
end
131143

132144
## PartitionedGraphs
133145
function PartitionedGraphs.quotient_graph(tn::TensorNetwork)
134146
ug = quotient_graph(underlying_graph(tn))
135-
return TensorNetwork(ug, vertex_data(QuotientView(tn)))
147+
148+
inds = Indices(parent_graph_indices(QuotientVertices(tn)))
149+
data = map(v -> tn[QuotientVertex(v)], inds)
150+
151+
return TensorNetwork(ug, data)
136152
end
153+
# TODO: This method should not be required with a better interface with a better
154+
# DataGraphsPartitionedGraphsExt interface.
137155
function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork})
138156
UG = quotient_graph_type(underlying_graph_type(type))
139157
VD = Vector{vertex_data_eltype(type)}
140158
V = vertextype(UG)
141159
return TensorNetwork{V, VD, UG, Dictionary{V, VD}}
142160
end
143161

162+
# Partition the underlying graph of the tensor network; does not affect the data.
144163
function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts)
145164
pg = partitionedgraph(underlying_graph(tn), parts)
146-
return TensorNetwork(pg, vertex_data(tn))
165+
return TensorNetwork(pg, copy(vertex_data(tn)))
147166
end
148167

149168
PartitionedGraphs.departition(tn::TensorNetwork) = tn
@@ -153,8 +172,9 @@ function PartitionedGraphs.departition(
153172
return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn))
154173
end
155174

156-
function DataGraphsPartitionedGraphsExt.to_quotient_vertex_data(::TensorNetwork, data)
157-
return mapreduce(lazy, *, collect(last(data)))
175+
function DataGraphs.get_vertices_data(tn::TensorNetwork, vertex::QuotientVertexVertices)
176+
data = collect(map(v -> tn[v], NamedGraphs.parent_graph_indices(vertex)))
177+
return mapreduce(lazy, *, data)
158178
end
159179

160180
function PartitionedGraphs.quotientview(tn::TensorNetwork)

0 commit comments

Comments
 (0)