Skip to content

Commit 94121f4

Browse files
author
NeiroYT
committed
Upd
1 parent 1966fe8 commit 94121f4

5 files changed

Lines changed: 902 additions & 142 deletions

File tree

include/graph/graph.hpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,6 @@ class Graph {
7777
split_distribution_ = std::move(split_dist);
7878
}
7979

80-
[[nodiscard]] int getVertexValue(size_t layerID) const {
81-
if (layerID >= arrayV_.size()) {
82-
throw std::invalid_argument("ArrayV does not contain this ID.");
83-
}
84-
return arrayV_[layerID];
85-
}
86-
87-
[[nodiscard]] int getEdgeValue(size_t pos) const {
88-
if (pos >= arrayE_.size()) {
89-
throw std::invalid_argument("ArrayE does not contain this.");
90-
}
91-
return arrayE_[pos];
92-
}
93-
9480
[[nodiscard]] size_t getInputsSize(size_t layerID) const {
9581
if (layerID >= in_edges_.size()) {
9682
throw std::invalid_argument("Input edges array do not contain this ID.");
@@ -105,6 +91,21 @@ class Graph {
10591
return in_edges_[layerID];
10692
}
10793

94+
[[nodiscard]] size_t getOutputsSize(size_t layerID) const {
95+
if (layerID >= layers_.size()) {
96+
throw std::invalid_argument("Layers array do not contain this ID.");
97+
}
98+
return arrayV_[layerID + 1] - arrayV_[layerID];
99+
}
100+
101+
[[nodiscard]] std::vector<int> getOutLayers(size_t layerID) const {
102+
if (layerID >= layers_.size()) {
103+
throw std::invalid_argument("Input edges array do not contain this ID.");
104+
}
105+
return std::vector<int>(arrayE_.begin() + arrayV_[layerID],
106+
arrayE_.begin() + arrayV_[layerID + 1]);
107+
}
108+
108109
[[nodiscard]] int getLayersCount() const { return V_; }
109110

110111
[[nodiscard]] std::shared_ptr<Layer> getLayerFromID(size_t layerID) const {
@@ -231,8 +232,9 @@ class Graph {
231232
}
232233
// remove outputs
233234
int amount_connected = arrayV_[id + 1] - arrayV_[id];
235+
std::vector<int> arrayE_copy = arrayE_;
234236
for (int i = 0; i < amount_connected; i++) {
235-
removeConnection(id, arrayE_[arrayV_[id] + i]);
237+
removeConnection(id, arrayE_copy[arrayV_[id] + i]);
236238
}
237239
// remove vertex
238240
in_edges_.erase(in_edges_.begin() + id);

include/graph_transformations/graph_transformations.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,26 @@
22
#include <vector>
33

44
#include "graph/graph.hpp"
5+
#include "layers/EWLayer.hpp"
56
#include "layers/Layer.hpp"
67

78
namespace it_lab_ai {
9+
10+
struct IOOrder {
11+
std::vector<int> in_order;
12+
std::vector<int> out_order;
13+
void fill_empty(size_t in_size, size_t out_size) {
14+
if (in_order.empty()) {
15+
in_order.resize(in_size);
16+
std::iota(in_order.begin(), in_order.end(), 0);
17+
}
18+
if (out_order.empty()) {
19+
out_order.resize(out_size);
20+
std::iota(out_order.begin(), out_order.end(), 0);
21+
}
22+
}
23+
};
24+
825
std::vector<std::vector<int>> find_subgraphs(const Graph& graph,
926
const Graph& subgraph);
1027
bool has_edge(const Graph& graph, int id_from, int id_to);
@@ -13,4 +30,14 @@ bool is_leaf(const Graph& graph, int id);
1330
bool run_search(const Graph& graph, const Graph& subgraph,
1431
std::vector<int>& assignments,
1532
std::vector<std::vector<int>>& results);
33+
34+
void change_ids(std::vector<std::vector<int>>& vec, int id);
35+
bool does_intersect(const std::vector<int>& vec1, const std::vector<int>& vec2);
36+
void changed_subgraphs(const Graph& graph, const Graph& subgraph_from,
37+
Graph& new_graph, Tensor& out,
38+
const RuntimeOptions& options = RuntimeOptions());
39+
void changed_subgraphs(const Graph& graph, const Graph& subgraph_from,
40+
const Graph& subgraph_to, Graph& new_graph, Tensor& out,
41+
const RuntimeOptions& options = RuntimeOptions(),
42+
IOOrder order = IOOrder());
1643
} // namespace it_lab_ai
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
file(GLOB_RECURSE graphT_src *.cpp)
22
add_library(graphT_lib STATIC "${GRAPHT_HEADERS}" "${graphT_src}")
3-
target_link_libraries(graphT_lib PUBLIC TBB_unified)
3+
target_link_libraries(graphT_lib PUBLIC graph_lib TBB_unified)
44
add_dependencies(graphT_lib kokkos_external)

0 commit comments

Comments
 (0)