Skip to content

Commit 66b2b24

Browse files
authored
Implement subgraph to node, subgraph to subgraph replacement (#270)
1 parent 156dacc commit 66b2b24

5 files changed

Lines changed: 915 additions & 145 deletions

File tree

include/graph/graph.hpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -83,34 +83,38 @@ class Graph {
8383
split_distribution_ = std::move(split_dist);
8484
}
8585

86-
[[nodiscard]] int getVertexValue(size_t layerID) const {
87-
if (layerID >= arrayV_.size()) {
88-
throw std::invalid_argument("ArrayV does not contain this ID.");
89-
}
90-
return arrayV_[layerID];
91-
}
92-
93-
[[nodiscard]] int getEdgeValue(size_t pos) const {
94-
if (pos >= arrayE_.size()) {
95-
throw std::invalid_argument("ArrayE does not contain this.");
96-
}
97-
return arrayE_[pos];
98-
}
99-
10086
[[nodiscard]] size_t getInputsSize(size_t layerID) const {
10187
if (layerID >= in_edges_.size()) {
102-
throw std::invalid_argument("Input edges array do not contain this ID.");
88+
throw std::invalid_argument(
89+
"Input edges array does not contain this ID.");
10390
}
10491
return in_edges_[layerID].size();
10592
}
10693

10794
[[nodiscard]] std::vector<int> getInLayers(size_t layerID) const {
10895
if (layerID >= in_edges_.size()) {
109-
throw std::invalid_argument("Input edges array do not contain this ID.");
96+
throw std::invalid_argument(
97+
"Input edges array does not contain this ID.");
11098
}
11199
return in_edges_[layerID];
112100
}
113101

102+
[[nodiscard]] size_t getOutputsSize(size_t layerID) const {
103+
if (layerID >= layers_.size()) {
104+
throw std::invalid_argument("Layers array does not contain this ID.");
105+
}
106+
return arrayV_[layerID + 1] - arrayV_[layerID];
107+
}
108+
109+
[[nodiscard]] std::vector<int> getOutLayers(size_t layerID) const {
110+
if (layerID >= layers_.size()) {
111+
throw std::invalid_argument(
112+
"Output edges array does not contain this ID.");
113+
}
114+
return std::vector<int>(arrayE_.begin() + arrayV_[layerID],
115+
arrayE_.begin() + arrayV_[layerID + 1]);
116+
}
117+
114118
[[nodiscard]] int getLayersCount() const { return V_; }
115119

116120
[[nodiscard]] std::shared_ptr<Layer> getLayerFromID(size_t layerID) const {
@@ -237,8 +241,9 @@ class Graph {
237241
}
238242
// remove outputs
239243
int amount_connected = arrayV_[id + 1] - arrayV_[id];
244+
std::vector<int> array_e_copy = arrayE_;
240245
for (int i = 0; i < amount_connected; i++) {
241-
removeConnection(id, arrayE_[arrayV_[id] + i]);
246+
removeConnection(id, array_e_copy[arrayV_[id] + i]);
242247
}
243248
// remove vertex
244249
in_edges_.erase(in_edges_.begin() + id);

include/graph_transformations/graph_transformations.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,26 @@
22
#include <vector>
33

44
#include "graph/graph.hpp"
5+
#include "layers/EWLayer.hpp"
56
#include "layers/Layer.hpp"
67

78
namespace it_lab_ai {
9+
10+
struct IOOrder {
11+
std::vector<int> in_order;
12+
std::vector<int> out_order;
13+
void fill_empty(size_t in_size, size_t out_size) {
14+
if (in_order.empty()) {
15+
in_order.resize(in_size);
16+
std::iota(in_order.begin(), in_order.end(), 0);
17+
}
18+
if (out_order.empty()) {
19+
out_order.resize(out_size);
20+
std::iota(out_order.begin(), out_order.end(), 0);
21+
}
22+
}
23+
};
24+
825
std::vector<std::vector<int>> find_subgraphs(const Graph& graph,
926
const Graph& subgraph);
1027
bool has_edge(const Graph& graph, int id_from, int id_to);
@@ -13,4 +30,15 @@ bool is_leaf(const Graph& graph, int id);
1330
bool run_search(const Graph& graph, const Graph& subgraph,
1431
std::vector<int>& assignments,
1532
std::vector<std::vector<int>>& results);
33+
34+
void change_ids(std::vector<std::vector<int>>& vec, int id);
35+
bool does_intersect(const std::vector<int>& vec1, const std::vector<int>& vec2);
36+
void changed_subgraphs(const Graph& graph, const Graph& subgraph_from,
37+
const std::shared_ptr<Layer>& layer_to, Graph& new_graph,
38+
Tensor& out,
39+
const RuntimeOptions& options = RuntimeOptions());
40+
void changed_subgraphs(const Graph& graph, const Graph& subgraph_from,
41+
const Graph& subgraph_to, Graph& new_graph, Tensor& out,
42+
const RuntimeOptions& options = RuntimeOptions(),
43+
IOOrder order = IOOrder());
1644
} // namespace it_lab_ai
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
file(GLOB_RECURSE graphT_src *.cpp)
22
add_library(graphT_lib STATIC "${GRAPHT_HEADERS}" "${graphT_src}")
3-
target_link_libraries(graphT_lib PUBLIC TBB_unified)
3+
target_link_libraries(graphT_lib PUBLIC graph_lib TBB_unified)
44
add_dependencies(graphT_lib kokkos_external)

0 commit comments

Comments
 (0)