@@ -24,12 +24,11 @@ bool Node::operator>(const Node& other) const {
2424
2525double TesseractDecoder::get_detcost (size_t d,
2626 const std::vector<char >& blocked_errs,
27- const std::vector<size_t >& det_counts,
28- const std::vector<char >& dets) const {
27+ const std::vector<size_t >& det_counts) const {
2928 double min_cost = INF;
3029 for (size_t ei : d2e[d]) {
3130 if (!blocked_errs[ei]) {
32- double ecost = ( errors[ei].likelihood_cost ) / det_counts[ei];
31+ double ecost = errors[ei].likelihood_cost / det_counts[ei];
3332 min_cost = std::min (min_cost, ecost);
3433 assert (det_counts[ei]);
3534 }
@@ -47,7 +46,7 @@ TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) {
4746 assert (config.det_orders [i].size () == config.dem .count_detectors ());
4847 }
4948 }
50- assert (this -> config .det_orders .size ());
49+ assert (config.det_orders .size ());
5150 errors = get_errors_from_dem (config.dem .flattened ());
5251 if (config.verbose ) {
5352 for (auto & error : errors) {
@@ -121,7 +120,7 @@ void TesseractDecoder::decode_to_errors(
121120 size_t det_order = beam % config.det_orders .size ();
122121 decode_to_errors (detections, det_order);
123122 double this_cost = cost_from_errors (predicted_errors_buffer);
124- if (!low_confidence_flag and this_cost < best_cost) {
123+ if (!low_confidence_flag && this_cost < best_cost) {
125124 best_errors = predicted_errors_buffer;
126125 best_cost = this_cost;
127126 }
@@ -138,7 +137,7 @@ void TesseractDecoder::decode_to_errors(
138137 ++det_order) {
139138 decode_to_errors (detections, det_order);
140139 double this_cost = cost_from_errors (predicted_errors_buffer);
141- if (!low_confidence_flag and this_cost < best_cost) {
140+ if (!low_confidence_flag && this_cost < best_cost) {
142141 best_errors = predicted_errors_buffer;
143142 best_cost = this_cost;
144143 }
@@ -154,7 +153,7 @@ void TesseractDecoder::decode_to_errors(
154153 }
155154 config.det_beam = max_det_beam;
156155 predicted_errors_buffer = best_errors;
157- low_confidence_flag = ( best_cost == std::numeric_limits<double >::max () );
156+ low_confidence_flag = best_cost == std::numeric_limits<double >::max ();
158157}
159158
160159bool QNode::operator >(const QNode& other) const {
@@ -184,20 +183,16 @@ void TesseractDecoder::to_node(const QNode& qnode,
184183 // Reconstruct the blocked_errs
185184 for (size_t oei : d2e[min_det]) {
186185 node.blocked_errs [oei] = true ;
187- if (!config.at_most_two_errors_per_detector and oei == ei) break ;
186+ if (!config.at_most_two_errors_per_detector && oei == ei) break ;
188187 }
189188
190189 // Reconstruct the dets
191190 for (size_t d : edets[ei]) {
192- if (node.dets [d]) {
193- node.dets [d] = false ;
194- if (config.at_most_two_errors_per_detector ) {
195- for (size_t oei : d2e[d]) {
196- node.blocked_errs [oei] = true ;
197- }
191+ node.dets [d] = !node.dets [d];
192+ if (!node.dets [d] && config.at_most_two_errors_per_detector ) {
193+ for (size_t oei : d2e[d]) {
194+ node.blocked_errs [oei] = true ;
198195 }
199- } else {
200- node.dets [d] = true ;
201196 }
202197 }
203198 }
@@ -218,40 +213,37 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
218213 std::unordered_set<std::vector<char >, VectorCharHash>>
219214 discovered_dets;
220215
221- size_t min_num_dets;
222- {
223- std::vector<size_t > errs;
224- std::vector<char > blocked_errs (num_errors, false );
225- std::vector<size_t > det_counts (num_errors, 0 );
216+ size_t min_num_dets = detections.size ();
217+ std::vector<size_t > errs;
218+ std::vector<char > blocked_errs (num_errors, false );
219+ std::vector<size_t > det_counts (num_errors, 0 );
226220
227- for (size_t d = 0 ; d < num_detectors; ++d) {
228- if (!dets[d]) continue ;
229- for (int ei : d2e[d]) {
230- det_counts[ei]++;
231- }
232- }
233- double initial_cost = 0.0 ;
234- for (size_t d = 0 ; d < num_detectors; ++d) {
235- if (!dets[d]) continue ;
236- initial_cost += get_detcost (d, blocked_errs, det_counts, dets);
221+ for (size_t d = 0 ; d < num_detectors; ++d) {
222+ if (!dets[d]) continue ;
223+ for (int ei : d2e[d]) {
224+ ++det_counts[ei];
237225 }
238- if (initial_cost == INF) {
239- low_confidence_flag = true ;
240- return ;
241- }
242- min_num_dets =
243- static_cast <size_t >(std::count (dets.begin (), dets.end (), true ));
244- // pq.push({errs, dets, initial_cost, min_num_dets, blocked_errs});
245- pq.push ({initial_cost, min_num_dets, errs});
246226 }
247- size_t num_pq_pushed = 1 ;
227+ double initial_cost = 0.0 ;
228+ for (size_t d = 0 ; d < num_detectors; ++d) {
229+ if (!dets[d]) continue ;
230+ initial_cost += get_detcost (d, blocked_errs, det_counts);
231+ }
232+ if (initial_cost == INF) {
233+ low_confidence_flag = true ;
234+ return ;
235+ }
236+ // pq.push({errs, dets, initial_cost, min_num_dets, blocked_errs});
237+ pq.push ({initial_cost, min_num_dets, errs});
248238
239+ size_t num_pq_pushed = 1 ;
249240 size_t max_num_dets = min_num_dets + det_beam;
250241 Node node;
251242 std::vector<size_t > next_det_counts;
252243 std::vector<char > next_next_blocked_errs;
253244 std::vector<char > next_dets;
254245 std::vector<size_t > next_errs;
246+
255247 while (!pq.empty ()) {
256248 const QNode qnode = pq.top ();
257249 if (qnode.num_dets > max_num_dets) {
@@ -281,13 +273,12 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
281273 }
282274 // Store the predicted errors into the buffer
283275 predicted_errors_buffer = node.errs ;
284-
285276 return ;
286277 }
287278
288279 if (node.num_dets > max_num_dets) continue ;
289280
290- if (config.no_revisit_dets and
281+ if (config.no_revisit_dets &&
291282 !discovered_dets[node.num_dets ].insert (node.dets ).second ) {
292283 continue ;
293284 }
@@ -337,9 +328,10 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
337328 for (size_t d = 0 ; d < num_detectors; ++d) {
338329 if (!node.dets [d]) continue ;
339330 for (int ei : d2e[d]) {
340- det_counts[ei]++ ;
331+ ++ det_counts[ei];
341332 }
342333 }
334+
343335 // We cache as we recompute the det costs
344336 std::vector<double > det_costs (num_detectors, -1 );
345337 std::vector<char > next_blocked_errs = node.blocked_errs ;
@@ -363,19 +355,14 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
363355 // iteration
364356 if (last_ei != std::numeric_limits<size_t >::max ()) {
365357 for (int d : edets[last_ei]) {
366- if (node.dets [d]) {
367- for (int oei : d2e[d]) {
368- ++next_det_counts[oei];
369- }
370- } else {
371- for (int oei : d2e[d]) {
372- --next_det_counts[oei];
373- }
358+ int fired = node.dets [d] ? 1 : -1 ;
359+ for (int oei : d2e[d]) {
360+ next_det_counts[oei] += fired;
374361 }
375362 }
376363 }
377- last_ei = ei;
378364
365+ last_ei = ei;
379366 next_blocked_errs[ei] = true ;
380367
381368 next_errs = node.errs ;
@@ -385,24 +372,21 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
385372 double next_cost = node.cost + errors[ei].likelihood_cost ;
386373
387374 size_t next_num_dets = node.num_dets ;
388- next_next_blocked_errs = next_blocked_errs;
375+ if (config.at_most_two_errors_per_detector ) {
376+ next_next_blocked_errs = next_blocked_errs;
377+ }
378+
389379 for (int d : edets[ei]) {
390- if (next_dets[d]) {
391- next_dets[d] = false ;
392- --next_num_dets;
393- for (int oei : d2e[d]) {
394- --next_det_counts[oei];
395- }
396- if (config.at_most_two_errors_per_detector ) {
397- for (size_t oei : d2e[d]) {
398- next_next_blocked_errs[oei] = true ;
399- }
400- }
401- } else {
402- next_dets[d] = true ;
403- ++next_num_dets;
404- for (int oei : d2e[d]) {
405- ++next_det_counts[oei];
380+ next_dets[d] = !next_dets[d];
381+ int fired = next_dets[d] ? 1 : -1 ;
382+ next_num_dets += fired;
383+ for (int oei : d2e[d]) {
384+ next_det_counts[oei] += fired;
385+ }
386+
387+ if (!next_dets[d] && config.at_most_two_errors_per_detector ) {
388+ for (size_t oei : d2e[d]) {
389+ next_next_blocked_errs[oei] = true ;
406390 }
407391 }
408392 }
@@ -411,7 +395,7 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
411395 continue ;
412396 }
413397
414- if (config.no_revisit_dets and
398+ if (config.no_revisit_dets &&
415399 discovered_dets[next_num_dets].find (next_dets) !=
416400 discovered_dets[next_num_dets].end ()) {
417401 continue ;
@@ -421,23 +405,22 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
421405 if (node.dets [d]) {
422406 if (det_costs[d] == -1 ) {
423407 det_costs[d] =
424- get_detcost (d, node.blocked_errs , det_counts, node. dets );
408+ get_detcost (d, node.blocked_errs , det_counts);
425409 }
426410 next_cost -= det_costs[d];
427411 } else {
428- next_cost += get_detcost (d, next_next_blocked_errs, next_det_counts,
429- next_dets);
412+ next_cost += get_detcost (d, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts);
430413 }
431414 }
432415 for (size_t od : eneighbors[ei]) {
433416 if (!node.dets [od] || !next_dets[od]) continue ;
434417 if (det_costs[od] == -1 ) {
435418 det_costs[od] =
436- get_detcost (od, node.blocked_errs , det_counts, node. dets );
419+ get_detcost (od, node.blocked_errs , det_counts);
437420 }
438421 next_cost -= det_costs[od];
439422 next_cost +=
440- get_detcost (od, next_next_blocked_errs, next_det_counts, next_dets );
423+ get_detcost (od, config. at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts );
441424 }
442425
443426 if (next_cost == INF) {
@@ -497,4 +480,4 @@ void TesseractDecoder::decode_shots(
497480 for (size_t i = 0 ; i < shots.size (); ++i) {
498481 obs_predicted[i] = decode (shots[i].hits );
499482 }
500- }
483+ }
0 commit comments