Skip to content

Commit 6d815f8

Browse files
Minor changes
Signed-off-by: Dragana Grbic <draganaurosgrbic@gmail.com> n# modified: src/tesseract.cc
1 parent f1ddcd3 commit 6d815f8

2 files changed

Lines changed: 59 additions & 76 deletions

File tree

src/tesseract.cc

Lines changed: 55 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,11 @@ bool Node::operator>(const Node& other) const {
2424

2525
double TesseractDecoder::get_detcost(size_t d,
2626
const std::vector<char>& blocked_errs,
27-
const std::vector<size_t>& det_counts,
28-
const std::vector<char>& dets) const {
27+
const std::vector<size_t>& det_counts) const {
2928
double min_cost = INF;
3029
for (size_t ei : d2e[d]) {
3130
if (!blocked_errs[ei]) {
32-
double ecost = (errors[ei].likelihood_cost) / det_counts[ei];
31+
double ecost = errors[ei].likelihood_cost / det_counts[ei];
3332
min_cost = std::min(min_cost, ecost);
3433
assert(det_counts[ei]);
3534
}
@@ -46,7 +45,7 @@ TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) {
4645
assert(config.det_orders[i].size() == config.dem.count_detectors());
4746
}
4847
}
49-
assert(this->config.det_orders.size());
48+
assert(config.det_orders.size());
5049
errors = get_errors_from_dem(config.dem.flattened());
5150
num_detectors = config.dem.count_detectors();
5251
num_errors = config.dem.count_errors();
@@ -112,7 +111,7 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections)
112111
size_t det_order = beam % config.det_orders.size();
113112
decode_to_errors(detections, det_order);
114113
double this_cost = cost_from_errors(predicted_errors_buffer);
115-
if (!low_confidence_flag and this_cost < best_cost) {
114+
if (!low_confidence_flag && this_cost < best_cost) {
116115
best_errors = predicted_errors_buffer;
117116
best_cost = this_cost;
118117
}
@@ -129,7 +128,7 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections)
129128
++det_order) {
130129
decode_to_errors(detections, det_order);
131130
double this_cost = cost_from_errors(predicted_errors_buffer);
132-
if (!low_confidence_flag and this_cost < best_cost) {
131+
if (!low_confidence_flag && this_cost < best_cost) {
133132
best_errors = predicted_errors_buffer;
134133
best_cost = this_cost;
135134
}
@@ -145,7 +144,7 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections)
145144
}
146145
config.det_beam = max_det_beam;
147146
predicted_errors_buffer = best_errors;
148-
low_confidence_flag = (best_cost == std::numeric_limits<double>::max());
147+
low_confidence_flag = best_cost == std::numeric_limits<double>::max();
149148
}
150149

151150
bool QNode::operator>(const QNode& other) const {
@@ -175,20 +174,16 @@ void TesseractDecoder::to_node(const QNode& qnode,
175174
// Reconstruct the blocked_errs
176175
for (size_t oei : d2e[min_det]) {
177176
node.blocked_errs[oei] = true;
178-
if (!config.at_most_two_errors_per_detector and oei == ei) break;
177+
if (!config.at_most_two_errors_per_detector && oei == ei) break;
179178
}
180179

181180
// Reconstruct the dets
182181
for (size_t d : edets[ei]) {
183-
if (node.dets[d]) {
184-
node.dets[d] = false;
185-
if (config.at_most_two_errors_per_detector) {
186-
for (size_t oei : d2e[d]) {
187-
node.blocked_errs[oei] = true;
188-
}
182+
node.dets[d] = !node.dets[d];
183+
if (!node.dets[d] && config.at_most_two_errors_per_detector) {
184+
for (size_t oei : d2e[d]) {
185+
node.blocked_errs[oei] = true;
189186
}
190-
} else {
191-
node.dets[d] = true;
192187
}
193188
}
194189
}
@@ -209,40 +204,37 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
209204
std::unordered_set<std::vector<char>, VectorCharHash>>
210205
discovered_dets;
211206

212-
size_t min_num_dets;
213-
{
214-
std::vector<size_t> errs;
215-
std::vector<char> blocked_errs(num_errors, false);
216-
std::vector<size_t> det_counts(num_errors, 0);
207+
size_t min_num_dets = detections.size();
208+
std::vector<size_t> errs;
209+
std::vector<char> blocked_errs(num_errors, false);
210+
std::vector<size_t> det_counts(num_errors, 0);
217211

218-
for (size_t d = 0; d < num_detectors; ++d) {
219-
if (!dets[d]) continue;
220-
for (int ei : d2e[d]) {
221-
det_counts[ei]++;
222-
}
212+
for (size_t d = 0; d < num_detectors; ++d) {
213+
if (!dets[d]) continue;
214+
for (int ei : d2e[d]) {
215+
++det_counts[ei];
223216
}
224-
double initial_cost = 0.0;
225-
for (size_t d = 0; d < num_detectors; ++d) {
226-
if (!dets[d]) continue;
227-
initial_cost += get_detcost(d, blocked_errs, det_counts, dets);
228-
}
229-
if (initial_cost == INF) {
230-
low_confidence_flag = true;
231-
return;
232-
}
233-
min_num_dets =
234-
static_cast<size_t>(std::count(dets.begin(), dets.end(), true));
235-
// pq.push({errs, dets, initial_cost, min_num_dets, blocked_errs});
236-
pq.push({initial_cost, min_num_dets, errs});
237217
}
238-
size_t num_pq_pushed = 1;
218+
double initial_cost = 0.0;
219+
for (size_t d = 0; d < num_detectors; ++d) {
220+
if (!dets[d]) continue;
221+
initial_cost += get_detcost(d, blocked_errs, det_counts);
222+
}
223+
if (initial_cost == INF) {
224+
low_confidence_flag = true;
225+
return;
226+
}
227+
// pq.push({errs, dets, initial_cost, min_num_dets, blocked_errs});
228+
pq.push({initial_cost, min_num_dets, errs});
239229

230+
size_t num_pq_pushed = 1;
240231
size_t max_num_dets = min_num_dets + det_beam;
241232
Node node;
242233
std::vector<size_t> next_det_counts;
243234
std::vector<char> next_next_blocked_errs;
244235
std::vector<char> next_dets;
245236
std::vector<size_t> next_errs;
237+
246238
while (!pq.empty()) {
247239
const QNode qnode = pq.top();
248240
if (qnode.num_dets > max_num_dets) {
@@ -260,13 +252,12 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
260252
}
261253
// Store the predicted errors into the buffer
262254
predicted_errors_buffer = node.errs;
263-
264255
return;
265256
}
266257

267258
if (node.num_dets > max_num_dets) continue;
268259

269-
if (config.no_revisit_dets and
260+
if (config.no_revisit_dets &&
270261
!discovered_dets[node.num_dets].insert(node.dets).second) {
271262
continue;
272263
}
@@ -308,9 +299,10 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
308299
for (size_t d = 0; d < num_detectors; ++d) {
309300
if (!node.dets[d]) continue;
310301
for (int ei : d2e[d]) {
311-
det_counts[ei]++;
302+
++det_counts[ei];
312303
}
313304
}
305+
314306
// We cache as we recompute the det costs
315307
std::vector<double> det_costs(num_detectors, -1);
316308
std::vector<char> next_blocked_errs = node.blocked_errs;
@@ -334,19 +326,14 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
334326
// iteration
335327
if (last_ei != std::numeric_limits<size_t>::max()) {
336328
for (int d : edets[last_ei]) {
337-
if (node.dets[d]) {
338-
for (int oei : d2e[d]) {
339-
++next_det_counts[oei];
340-
}
341-
} else {
342-
for (int oei : d2e[d]) {
343-
--next_det_counts[oei];
344-
}
329+
int fired = node.dets[d] ? 1 : -1;
330+
for (int oei : d2e[d]) {
331+
next_det_counts[oei] += fired;
345332
}
346333
}
347334
}
348-
last_ei = ei;
349335

336+
last_ei = ei;
350337
next_blocked_errs[ei] = true;
351338

352339
next_errs = node.errs;
@@ -359,23 +346,18 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
359346
if (config.at_most_two_errors_per_detector) {
360347
next_next_blocked_errs = next_blocked_errs;
361348
}
349+
362350
for (int d : edets[ei]) {
363-
if (next_dets[d]) {
364-
next_dets[d] = false;
365-
--next_num_dets;
366-
for (int oei : d2e[d]) {
367-
--next_det_counts[oei];
368-
}
369-
if (config.at_most_two_errors_per_detector) {
370-
for (size_t oei : d2e[d]) {
371-
next_next_blocked_errs[oei] = true;
372-
}
373-
}
374-
} else {
375-
next_dets[d] = true;
376-
++next_num_dets;
377-
for (int oei : d2e[d]) {
378-
++next_det_counts[oei];
351+
next_dets[d] = !next_dets[d];
352+
int fired = next_dets[d] ? 1 : -1;
353+
next_num_dets += fired;
354+
for (int oei : d2e[d]) {
355+
next_det_counts[oei] += fired;
356+
}
357+
358+
if (!next_dets[d] && config.at_most_two_errors_per_detector) {
359+
for (size_t oei : d2e[d]) {
360+
next_next_blocked_errs[oei] = true;
379361
}
380362
}
381363
}
@@ -384,7 +366,7 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
384366
continue;
385367
}
386368

387-
if (config.no_revisit_dets and
369+
if (config.no_revisit_dets &&
388370
discovered_dets[next_num_dets].find(next_dets) !=
389371
discovered_dets[next_num_dets].end()) {
390372
continue;
@@ -394,23 +376,22 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
394376
if (node.dets[d]) {
395377
if (det_costs[d] == -1) {
396378
det_costs[d] =
397-
get_detcost(d, node.blocked_errs, det_counts, node.dets);
379+
get_detcost(d, node.blocked_errs, det_counts);
398380
}
399381
next_cost -= det_costs[d];
400382
} else {
401-
next_cost += get_detcost(d, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts,
402-
next_dets);
383+
next_cost += get_detcost(d, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts);
403384
}
404385
}
405386
for (size_t od : eneighbors[ei]) {
406387
if (!node.dets[od] || !next_dets[od]) continue;
407388
if (det_costs[od] == -1) {
408389
det_costs[od] =
409-
get_detcost(od, node.blocked_errs, det_counts, node.dets);
390+
get_detcost(od, node.blocked_errs, det_counts);
410391
}
411392
next_cost -= det_costs[od];
412393
next_cost +=
413-
get_detcost(od, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts, next_dets);
394+
get_detcost(od, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts);
414395
}
415396

416397
if (next_cost == INF) {

src/tesseract.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <string>
2020
#include <unordered_map>
2121
#include <vector>
22+
#include <unordered_set>
2223

2324
#include "common.h"
2425
#include "stim.h"
@@ -70,10 +71,12 @@ struct TesseractDecoder {
7071
// these detection events, using a specified detector ordering index.
7172
void decode_to_errors(const std::vector<uint64_t>& detections,
7273
size_t det_order);
74+
7375
// Returns the bitwise XOR of all the observables bitmasks of all errors in
7476
// the predicted errors buffer.
7577
common::ObservablesMask mask_from_errors(
7678
const std::vector<size_t>& predicted_errors);
79+
7780
// Returns the sum of the likelihood costs (minus-log-likelihood-ratios) of
7881
// all errors in the predicted errors buffer.
7982
double cost_from_errors(const std::vector<size_t>& predicted_errors);
@@ -97,8 +100,7 @@ struct TesseractDecoder {
97100

98101
void initialize_structures(size_t num_detectors);
99102
double get_detcost(size_t d, const std::vector<char>& blocked_errs,
100-
const std::vector<size_t>& det_counts,
101-
const std::vector<char>& dets) const;
103+
const std::vector<size_t>& det_counts) const;
102104
void to_node(const QNode& qnode, const std::vector<char>& shot_dets,
103105
size_t det_order, Node& node) const;
104106
};

0 commit comments

Comments
 (0)