@@ -127,24 +127,24 @@ bool does_intersect(const std::vector<int>& vec1,
127127}
128128
129129void changed_subgraphs (const Graph& graph, const Graph& subgraph_from,
130- Graph& new_graph, Tensor& out ,
131- const RuntimeOptions& options) {
130+ const std::shared_ptr<Layer>& layer_to, Graph& new_graph ,
131+ Tensor& out, const RuntimeOptions& options) {
132132 graph.clone (new_graph, out, options);
133133 std::vector<std::vector<int >> subs = find_subgraphs (graph, subgraph_from);
134134 std::vector<std::vector<int >> subs_c = subs;
135135 std::vector<bool > sub_used (subs.size (), true );
136136 std::vector<int > roots;
137- std::vector<int > leafs ;
137+ std::vector<int > leaves ;
138138 std::vector<int > roots_inps_final;
139- std::vector<int > leafs_outs_final ;
139+ std::vector<int > leaves_outs_final ;
140140 size_t amount_connected;
141141 size_t amount_connected_s;
142142 for (int v = 0 ; v < subgraph_from.getLayersCount (); v++) {
143143 if (is_root (subgraph_from, v)) {
144144 roots.push_back (v);
145145 }
146146 if (is_leaf (subgraph_from, v)) {
147- leafs .push_back (v);
147+ leaves .push_back (v);
148148 }
149149 }
150150 for (size_t i = 0 ; i < subs.size (); i++) {
@@ -160,10 +160,10 @@ void changed_subgraphs(const Graph& graph, const Graph& subgraph_from,
160160 sub_used[i] = false ;
161161 continue ;
162162 }
163- std::shared_ptr<Layer> layer = std::make_shared<EWLayer>( " relu " );
163+ std::shared_ptr<Layer> layer = layer_based_shared_copy (layer_to, options );
164164 std::vector<bool > is_root_special (roots.size (), false );
165165 roots_inps_final.clear ();
166- leafs_outs_final .clear ();
166+ leaves_outs_final .clear ();
167167 for (size_t j = 0 ; j < roots.size (); j++) {
168168 std::vector<int > root_inps = new_graph.getInLayers (subs[i][roots[j]]);
169169 // want subgraph -> single node
@@ -188,14 +188,14 @@ void changed_subgraphs(const Graph& graph, const Graph& subgraph_from,
188188 }
189189 }
190190 }
191- for (int leaf : leafs ) {
191+ for (int leaf : leaves ) {
192192 amount_connected = new_graph.getOutputsSize (subs[i][leaf]);
193193 for (size_t k = 0 ; k < amount_connected; k++) {
194194 int id = new_graph.getOutLayers (subs[i][leaf])[k];
195195 auto it =
196- std::find (leafs_outs_final .begin (), leafs_outs_final .end (), id);
197- if (it == leafs_outs_final .end ()) {
198- leafs_outs_final .push_back (id);
196+ std::find (leaves_outs_final .begin (), leaves_outs_final .end (), id);
197+ if (it == leaves_outs_final .end ()) {
198+ leaves_outs_final .push_back (id);
199199 }
200200 }
201201 }
@@ -211,8 +211,8 @@ void changed_subgraphs(const Graph& graph, const Graph& subgraph_from,
211211 roots_inps_final.begin (), [&](int elem) {
212212 return elem > subs[i][j] ? elem - 1 : elem;
213213 });
214- std::transform (leafs_outs_final .begin (), leafs_outs_final .end (),
215- leafs_outs_final .begin (), [&](int elem) {
214+ std::transform (leaves_outs_final .begin (), leaves_outs_final .end (),
215+ leaves_outs_final .begin (), [&](int elem) {
216216 return elem > subs[i][j] ? elem - 1 : elem;
217217 });
218218 }
@@ -223,7 +223,7 @@ void changed_subgraphs(const Graph& graph, const Graph& subgraph_from,
223223 if (roots_inps_final.empty ()) {
224224 new_graph.addSingleLayer (layer);
225225 }
226- for (int j : leafs_outs_final ) {
226+ for (int j : leaves_outs_final ) {
227227 new_graph.makeConnection (layer, new_graph.getLayerFromID (j));
228228 }
229229 }
@@ -236,44 +236,44 @@ void changed_subgraphs(const Graph& graph, const Graph& subgraph_from,
236236 std::vector<std::vector<int >> subs = find_subgraphs (graph, subgraph_from);
237237 std::vector<std::vector<int >> subs_c = subs;
238238 std::vector<bool > sub_used (subs.size (), true );
239- std::vector<int > roots ;
240- std::vector<int > leafs ;
241- std::vector<int > roots2 ;
242- std::vector<int > leafs2 ;
239+ std::vector<int > roots_from ;
240+ std::vector<int > leaves_from ;
241+ std::vector<int > roots_to ;
242+ std::vector<int > leaves_to ;
243243 std::vector<std::vector<int >> roots_inps_final;
244- std::vector<std::vector<int >> leafs_outs_final ;
244+ std::vector<std::vector<int >> leaves_outs_final ;
245245 size_t amount_connected;
246246 size_t amount_connected_s;
247247 for (int v = 0 ; v < subgraph_from.getLayersCount (); v++) {
248248 if (is_root (subgraph_from, v)) {
249- roots .push_back (v);
249+ roots_from .push_back (v);
250250 }
251251 if (is_leaf (subgraph_from, v)) {
252- leafs .push_back (v);
252+ leaves_from .push_back (v);
253253 }
254254 }
255255 for (int v = 0 ; v < subgraph_to.getLayersCount (); v++) {
256256 if (is_root (subgraph_to, v)) {
257- roots2 .push_back (v);
257+ roots_to .push_back (v);
258258 }
259259 if (is_leaf (subgraph_to, v)) {
260- leafs2 .push_back (v);
260+ leaves_to .push_back (v);
261261 }
262262 }
263- if (roots2 .size () != roots .size ()) {
263+ if (roots_to .size () != roots_from .size ()) {
264264 throw std::invalid_argument (
265265 " Subgraph_to and Subgraph_from roots amounts aren't same." );
266266 }
267- if (leafs2 .size () != leafs .size ()) {
267+ if (leaves_to .size () != leaves_from .size ()) {
268268 throw std::invalid_argument (
269- " Subgraph_to and Subgraph_from leafs amounts aren't same." );
269+ " Subgraph_to and Subgraph_from leaves amounts aren't same." );
270270 }
271- order.fill_empty (roots .size (), leafs .size ());
272- if (order.in_order .size () != roots .size ()) {
271+ order.fill_empty (roots_from .size (), leaves_from .size ());
272+ if (order.in_order .size () != roots_from .size ()) {
273273 throw std::invalid_argument (" Order for roots isn't complete" );
274274 }
275- if (order.out_order .size () != leafs .size ()) {
276- throw std::invalid_argument (" Order for leafs isn't complete" );
275+ if (order.out_order .size () != leaves_from .size ()) {
276+ throw std::invalid_argument (" Order for leaves isn't complete" );
277277 }
278278 for (size_t i = 0 ; i < subs.size (); i++) {
279279 bool flag = false ;
@@ -288,91 +288,91 @@ void changed_subgraphs(const Graph& graph, const Graph& subgraph_from,
288288 sub_used[i] = false ;
289289 continue ;
290290 }
291- std::vector<bool > is_root_special (roots .size (), false );
291+ std::vector<bool > is_root_special (roots_from .size (), false );
292292 roots_inps_final =
293- std::vector<std::vector<int >>(roots .size (), std::vector<int >());
294- leafs_outs_final =
295- std::vector<std::vector<int >>(leafs .size (), std::vector<int >());
296- for (size_t j = 0 ; j < roots .size (); j++) {
297- roots_inps_final[j] = new_graph.getInLayers (subs[i][roots [j]]);
293+ std::vector<std::vector<int >>(roots_from .size (), std::vector<int >());
294+ leaves_outs_final =
295+ std::vector<std::vector<int >>(leaves_from .size (), std::vector<int >());
296+ for (size_t j = 0 ; j < roots_from .size (); j++) {
297+ roots_inps_final[j] = new_graph.getInLayers (subs[i][roots_from [j]]);
298298 // recognize transformations we can apply with roots
299- amount_connected = new_graph.getOutputsSize (subs[i][roots [j]]);
300- amount_connected_s = subgraph_from.getOutputsSize (roots [j]);
299+ amount_connected = new_graph.getOutputsSize (subs[i][roots_from [j]]);
300+ amount_connected_s = subgraph_from.getOutputsSize (roots_from [j]);
301301 if (amount_connected == amount_connected_s) {
302302 continue ;
303303 }
304304 for (size_t k = 0 ; k < amount_connected; k++) {
305- int id = new_graph.getOutLayers (subs[i][roots [j]])[k];
305+ int id = new_graph.getOutLayers (subs[i][roots_from [j]])[k];
306306 auto it = std::find (subs[i].begin (), subs[i].end (), id);
307307 if (it == subs[i].end ()) {
308308 is_root_special[j] = true ;
309309 }
310310 }
311311 }
312- for (size_t j = 0 ; j < leafs .size (); j++) {
313- amount_connected = new_graph.getOutputsSize (subs[i][leafs [j]]);
312+ for (size_t j = 0 ; j < leaves_from .size (); j++) {
313+ amount_connected = new_graph.getOutputsSize (subs[i][leaves_from [j]]);
314314 for (size_t k = 0 ; k < amount_connected; k++) {
315- int id = new_graph.getOutLayers (subs[i][leafs [j]])[k];
316- leafs_outs_final [j].push_back (id);
315+ int id = new_graph.getOutLayers (subs[i][leaves_from [j]])[k];
316+ leaves_outs_final [j].push_back (id);
317317 }
318318 }
319319 for (size_t j = 0 ; j < subs[i].size (); j++) {
320- auto it = std::find (roots .begin (), roots .end (), j);
321- size_t index_for_root = std::distance (roots .begin (), it);
320+ auto it = std::find (roots_from .begin (), roots_from .end (), j);
321+ size_t index_for_root = std::distance (roots_from .begin (), it);
322322 // remove all nodes that isn't special roots
323- if (it == roots .end () ||
324- (it != roots .end () && !is_root_special[index_for_root])) {
323+ if (it == roots_from .end () ||
324+ (it != roots_from .end () && !is_root_special[index_for_root])) {
325325 new_graph.removeSingleLayer (subs[i][j]);
326326 change_ids (subs, subs[i][j]);
327327 for (auto & k : roots_inps_final) {
328328 std::transform (k.begin (), k.end (), k.begin (), [&](int elem) {
329329 return elem > subs[i][j] ? elem - 1 : elem;
330330 });
331331 }
332- for (auto & k : leafs_outs_final ) {
332+ for (auto & k : leaves_outs_final ) {
333333 std::transform (k.begin (), k.end (), k.begin (), [&](int elem) {
334334 return elem > subs[i][j] ? elem - 1 : elem;
335335 });
336336 }
337337 }
338338 }
339- std::vector<int > roots2_c = roots2 ;
340- std::vector<int > leafs2_c = leafs2 ;
339+ std::vector<int > roots_to_c = roots_to ;
340+ std::vector<int > leaves_to_c = leaves_to ;
341341 std::vector<std::shared_ptr<Layer>> layers;
342342 for (int j = 0 ; j < subgraph_to.getLayersCount (); j++) {
343343 std::shared_ptr<Layer> layer =
344344 layer_based_shared_copy (subgraph_to.getLayerFromID (j), options);
345345 layers.push_back (layer);
346346 new_graph.addSingleLayer (layer);
347- auto it = std::find (roots2_c .begin (), roots2_c .end (), j);
348- if (it != roots2_c .end ()) {
349- size_t index_for_root = std::distance (roots2_c .begin (), it);
350- roots2 [index_for_root] = layer->getID ();
347+ auto it = std::find (roots_to_c .begin (), roots_to_c .end (), j);
348+ if (it != roots_to_c .end ()) {
349+ size_t index_for_root = std::distance (roots_to_c .begin (), it);
350+ roots_to [index_for_root] = layer->getID ();
351351 }
352- it = std::find (leafs2_c .begin (), leafs2_c .end (), j);
353- if (it != leafs2_c .end ()) {
354- size_t index_for_leaf = std::distance (leafs2_c .begin (), it);
355- leafs2 [index_for_leaf] = layer->getID ();
352+ it = std::find (leaves_to_c .begin (), leaves_to_c .end (), j);
353+ if (it != leaves_to_c .end ()) {
354+ size_t index_for_leaf = std::distance (leaves_to_c .begin (), it);
355+ leaves_to [index_for_leaf] = layer->getID ();
356356 }
357357 }
358358 for (int j = 0 ; j < subgraph_to.getLayersCount (); j++) {
359359 std::vector<int > cur_outs = subgraph_to.getOutLayers (j);
360- for (size_t k = 0 ; k < cur_outs. size (); k++ ) {
361- new_graph.makeConnection (layers[j], layers[cur_outs[k] ]);
360+ for (int cur_out : cur_outs) {
361+ new_graph.makeConnection (layers[j], layers[cur_out ]);
362362 }
363363 }
364364 for (size_t j = 0 ; j < roots_inps_final.size (); j++) {
365365 for (size_t k = 0 ; k < roots_inps_final[j].size (); k++) {
366366 new_graph.makeConnection (
367367 new_graph.getLayerFromID (roots_inps_final[j][k]),
368- new_graph.getLayerFromID (roots2 [order.in_order [j]]));
368+ new_graph.getLayerFromID (roots_to [order.in_order [j]]));
369369 }
370370 }
371- for (size_t j = 0 ; j < leafs_outs_final .size (); j++) {
372- for (size_t k = 0 ; k < leafs_outs_final [j].size (); k++) {
371+ for (size_t j = 0 ; j < leaves_outs_final .size (); j++) {
372+ for (size_t k = 0 ; k < leaves_outs_final [j].size (); k++) {
373373 new_graph.makeConnection (
374- new_graph.getLayerFromID (leafs2 [order.out_order [j]]),
375- new_graph.getLayerFromID (leafs_outs_final [j][k]));
374+ new_graph.getLayerFromID (leaves_to [order.out_order [j]]),
375+ new_graph.getLayerFromID (leaves_outs_final [j][k]));
376376 }
377377 }
378378 }
0 commit comments