11#pragma once
22#include < algorithm>
33#include < chrono>
4- #include < list>
54#include < memory>
65#include < queue>
76#include < stdexcept>
87#include < string>
98#include < thread>
9+ #include < unordered_map>
1010#include < utility>
1111#include < vector>
1212
@@ -34,7 +34,7 @@ class Graph {
3434 Tensor* outtenres_;
3535 int start_;
3636 int end_;
37- std::list< BranchState> branch_list_ ;
37+ std::unordered_map< int , BranchState> branch_map_ ;
3838 std::vector<std::vector<int >> in_edges_; // next -> prev
3939 std::vector<std::vector<std::pair<int , int >>> split_distribution_;
4040 int count_used_split_distribution_;
@@ -118,13 +118,9 @@ class Graph {
118118 if (!layer) {
119119 throw std::invalid_argument (" Layer cannot be null" );
120120 }
121- bool layer_exists = false ;
122- for (std::shared_ptr<Layer>& existing_layer : layers_) {
123- if (existing_layer == layer) {
124- layer_exists = true ;
125- break ;
126- }
127- }
121+
122+ int id = layer->getID ();
123+ bool layer_exists = (id >= 0 && id < V_ && layers_[id] == layer);
128124
129125 if (!layer_exists) {
130126 layer->setID (V_ );
@@ -144,13 +140,9 @@ class Graph {
144140
145141 void addSingleLayer (const std::shared_ptr<Layer>& layer) {
146142 if (!layer) return ;
147- bool layer_exists = false ;
148- for (const std::shared_ptr<Layer>& existing_layer : layers_) {
149- if (existing_layer == layer) {
150- layer_exists = true ;
151- break ;
152- }
153- }
143+
144+ int id = layer->getID ();
145+ bool layer_exists = (id >= 0 && id < V_ && layers_[id] == layer);
154146
155147 if (!layer_exists) {
156148 layer->setID (V_ );
@@ -296,31 +288,25 @@ class Graph {
296288
297289 for (size_t k = 0 ; k < in_edges_[current_layer].size (); ++k) {
298290 auto target_value = in_edges_[current_layer][k];
299- auto it = std::find_if (branch_list_.rbegin (), branch_list_.rend (),
300- [target_value](const BranchState& s) {
301- return s.ind_layer == target_value;
302- });
303-
304- if (it != branch_list_.rend ()) {
305- for (size_t f = 0 ; f < it->distribution .size (); ++f) {
306- if (it->distribution [f].first == current_layer) {
307- bool last_use = (it->count_used_ten == 1 );
308- auto & src = it->give_for_all [it->distribution [f].second ];
291+ auto it = branch_map_.find (target_value);
292+
293+ if (it != branch_map_.end ()) {
294+ for (size_t f = 0 ; f < it->second .distribution .size (); ++f) {
295+ if (it->second .distribution [f].first == current_layer) {
296+ bool last_use = (it->second .count_used_ten == 1 );
297+ auto & src =
298+ it->second .give_for_all [it->second .distribution [f].second ];
309299 if (last_use) {
310300 inten_.push_back (std::move (src));
311301 } else {
312302 inten_.push_back (src);
313303 }
314304 }
315305 }
316- }
317306
318- if (it != branch_list_.rend ()) {
319- it->count_used_ten --;
320- if (it->count_used_ten < 1 ) {
321- auto rit = std::next (it).base ();
322- it =
323- std::reverse_iterator<decltype (rit)>(branch_list_.erase (rit));
307+ it->second .count_used_ten --;
308+ if (it->second .count_used_ten < 1 ) {
309+ branch_map_.erase (it);
324310 }
325311 }
326312 }
@@ -375,11 +361,11 @@ class Graph {
375361 }
376362 new_branch.distribution = dis;
377363 }
378- branch_list_. push_back ( std::move (new_branch) );
364+ branch_map_[current_layer] = std::move (new_branch);
379365 if (outtenres_ && current_layer == end_ &&
380- !branch_list_. back () .give_for_all .empty () &&
366+ !branch_map_[current_layer] .give_for_all .empty () &&
381367 countinout[current_layer].second == 0 ) {
382- *outtenres_ = std::move (branch_list_. back () .give_for_all [0 ]);
368+ *outtenres_ = std::move (branch_map_[current_layer] .give_for_all [0 ]);
383369 }
384370
385371#ifdef ENABLE_STATISTIC_TIME
0 commit comments