Skip to content

Commit 4076180

Browse files
Changes 'next_errors' variable in main decoding loop to a linked list (#173)
* Prevents unnecessary vector copies as we can just rebuild the history when we need to. * Also use arena allocator to prevent repeated heap allocation. Tested with: ``` (venv13) dandragona@kaonashi:~/Documents/tesseract/indep/optimizations$ bazel-8.2.1 test --cache_test_results=no //src:all //src/py:all INFO: Analyzed 24 targets (0 packages loaded, 23 targets configured). INFO: Found 16 targets and 8 test targets... INFO: Elapsed time: 6.390s, Critical Path: 5.97s INFO: 9 processes: 42 action cache hit, 12 linux-sandbox, 2 local. INFO: Build completed successfully, 9 total actions //src:common_tests PASSED in 0.0s //src:tesseract_tests PASSED in 0.7s //src/py:common_test PASSED in 1.5s //src/py:requirements_test PASSED in 5.9s //src/py:simplex_test PASSED in 1.6s //src/py:tesseract_sinter_compat_test PASSED in 4.6s //src/py:tesseract_test PASSED in 1.5s //src/py:utils_test PASSED in 1.0s Executed 8 out of 8 tests: 8 tests pass. ``` Additional accuracy comparisons: ``` # d=11, p=0.001 Surface Code Transversal CNOT X # baseline: num_shots = 5000 num_low_confidence = 0 num_errors = 25 total_time_seconds = 364.1171770000001 # arena allocator + linked list optimization: num_shots = 5000 num_low_confidence = 0 num_errors = 25 total_time_seconds = 253.3283590000004 # d=9, p=.002 Superdense Color Code X # baseline: num_shots = 5000 num_low_confidence = 0 num_errors = 64 total_time_seconds = 604.9724190000003 # arena allocator + linked list optimization: num_shots = 5000 num_low_confidence = 0 num_errors = 64 total_time_seconds = 585.220854999997 # d=12, p=0.002 Bivariate Bicycle X # baseline: num_shots = 5000 num_low_confidence = 0 num_errors = 15 total_time_seconds = 3232.868145999989 # arena allocator + linked list optimization: num_shots = 5000 num_low_confidence = 0 num_errors = 15 total_time_seconds = 3099.478287000002 # d=23, p=0.008 Surface Code Unrotated Memory Z # baseline: num_shots = 5000 num_low_confidence = 0 num_errors = 0 total_time_seconds = 803.0619609999993 # arena allocator + linked list optimization: num_shots = 5000 num_low_confidence = 0 num_errors = 0 total_time_seconds = 142.359363 ``` Co-authored-by: Noah Shutty <noajshu@users.noreply.github.com>
1 parent bd84959 commit 4076180

5 files changed

Lines changed: 60 additions & 29 deletions

File tree

src/common.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ struct Symptom {
4545
std::string str() const;
4646
};
4747

48+
// Represents a specific subset of errors in the power set of all errors.
49+
// `parent_idx` is the index of the parent node in the error chain arena, and is
50+
// used to trace back to the root of the error set.
51+
struct ErrorChainNode {
52+
size_t error_index;
53+
size_t min_detector;
54+
int64_t parent_idx = -1;
55+
};
56+
4857
// Represents an error / weighted hyperedge
4958
struct Error {
5059
double likelihood_cost;

src/tesseract.cc

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ std::string Node::str() {
7272
std::stringstream ss;
7373
auto& self = *this;
7474
ss << "Node(";
75-
ss << "errors=" << self.errors << ", ";
75+
ss << "error_chain_idx=" << self.error_chain_idx << ", ";
7676
ss << "cost=" << self.cost << ", ";
7777
ss << "num_dets=" << self.num_dets << ", ";
7878
return ss.str();
@@ -243,16 +243,13 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections)
243243
}
244244

245245
void TesseractDecoder::flip_detectors_and_block_errors(
246-
size_t detector_order, const std::vector<size_t>& errors, boost::dynamic_bitset<>& detectors,
246+
size_t detector_order, int64_t error_chain_idx, boost::dynamic_bitset<>& detectors,
247247
std::vector<DetectorCostTuple>& detector_cost_tuples) const {
248-
for (size_t ei : errors) {
249-
size_t min_detector = std::numeric_limits<size_t>::max();
250-
for (size_t d = 0; d < num_detectors; ++d) {
251-
if (detectors[config.det_orders[detector_order][d]]) {
252-
min_detector = config.det_orders[detector_order][d];
253-
break;
254-
}
255-
}
248+
int64_t walker_idx = error_chain_idx;
249+
while (walker_idx != -1) {
250+
const auto& node = error_chain_arena[walker_idx];
251+
size_t ei = node.error_index;
252+
size_t min_detector = node.min_detector;
256253

257254
for (int oei : d2e[min_detector]) {
258255
detector_cost_tuples[oei].error_blocked = 1;
@@ -262,13 +259,18 @@ void TesseractDecoder::flip_detectors_and_block_errors(
262259
for (int d : edets[ei]) {
263260
detectors[d] = !detectors[d];
264261
}
262+
walker_idx = node.parent_idx;
265263
}
266264
}
267265

268266
void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
269267
size_t detector_order, size_t detector_beam) {
270268
predicted_errors_buffer.clear();
271269
low_confidence_flag = false;
270+
error_chain_arena.clear();
271+
// Can technically be larger than pqlimit, but we need an initial guess on how many nodes we
272+
// will process from the queue.
273+
error_chain_arena.reserve(config.pqlimit);
272274

273275
std::priority_queue<Node, std::vector<Node>, std::greater<Node>> pq;
274276
std::unordered_map<size_t, std::unordered_set<boost::dynamic_bitset<>>> visited_detectors;
@@ -296,11 +298,10 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
296298
size_t min_num_dets = detections.size();
297299
size_t max_num_dets = min_num_dets + detector_beam;
298300

299-
std::vector<size_t> next_errors;
300301
boost::dynamic_bitset<> next_detectors;
301302
std::vector<DetectorCostTuple> next_detector_cost_tuples;
302303

303-
pq.push({initial_cost, min_num_dets, std::vector<size_t>()});
304+
pq.push({initial_cost, min_num_dets, 0, -1});
304305
size_t num_pq_pushed = 1;
305306

306307
while (!pq.empty()) {
@@ -311,17 +312,20 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
311312

312313
boost::dynamic_bitset<> detectors = initial_detectors;
313314
std::vector<DetectorCostTuple> detector_cost_tuples(num_errors);
314-
flip_detectors_and_block_errors(detector_order, node.errors, detectors, detector_cost_tuples);
315+
flip_detectors_and_block_errors(detector_order, node.error_chain_idx, detectors,
316+
detector_cost_tuples);
315317

316318
if (node.num_dets == 0) {
317319
if (config.create_visualization) {
318-
visualizer.add_activated_errors(node.errors);
320+
visualizer.add_activated_errors(node.error_chain_idx, error_chain_arena);
319321
visualizer.add_activated_detectors(detectors, num_detectors);
320322
}
321323
if (config.verbose) {
322324
std::cout << "activated_errors = ";
323-
for (size_t oei : node.errors) {
324-
std::cout << oei << ", ";
325+
int64_t walker_idx = node.error_chain_idx;
326+
while (walker_idx != -1) {
327+
std::cout << error_chain_arena[walker_idx].error_index << ", ";
328+
walker_idx = error_chain_arena[walker_idx].parent_idx;
325329
}
326330
std::cout << std::endl;
327331
std::cout << "activated_detectors = ";
@@ -335,15 +339,20 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
335339
std::cout << "Decoding complete. Cost: " << node.cost
336340
<< " num_pq_pushed = " << num_pq_pushed << std::endl;
337341
}
338-
predicted_errors_buffer = node.errors;
342+
predicted_errors_buffer.resize(node.depth);
343+
int64_t walker_idx = node.error_chain_idx;
344+
for (size_t i = 0; i < node.depth; ++i) {
345+
predicted_errors_buffer[node.depth - 1 - i] = error_chain_arena[walker_idx].error_index;
346+
walker_idx = error_chain_arena[walker_idx].parent_idx;
347+
}
339348
return;
340349
}
341350

342351
if (config.no_revisit_dets && !visited_detectors[node.num_dets].insert(detectors).second)
343352
continue;
344353

345354
if (config.create_visualization) {
346-
visualizer.add_activated_errors(node.errors);
355+
visualizer.add_activated_errors(node.error_chain_idx, error_chain_arena);
347356
visualizer.add_activated_detectors(detectors, num_detectors);
348357
}
349358
if (config.verbose) {
@@ -352,8 +361,10 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
352361
std::cout << "num_dets = " << node.num_dets << " max_num_dets = " << max_num_dets
353362
<< " cost = " << node.cost << std::endl;
354363
std::cout << "activated_errors = ";
355-
for (size_t oei : node.errors) {
356-
std::cout << oei << ", ";
364+
int64_t walker_idx = node.error_chain_idx;
365+
while (walker_idx != -1) {
366+
std::cout << error_chain_arena[walker_idx].error_index << ", ";
367+
walker_idx = error_chain_arena[walker_idx].parent_idx;
357368
}
358369
std::cout << std::endl;
359370
std::cout << "activated_detectors = ";
@@ -408,8 +419,13 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
408419
}
409420
prev_ei = ei;
410421

411-
next_errors = node.errors;
412-
next_errors.push_back(ei);
422+
// Create the error chain node for this candidate.
423+
error_chain_arena.emplace_back();
424+
auto& next_node = error_chain_arena.back();
425+
next_node.error_index = ei;
426+
next_node.min_detector = min_detector;
427+
next_node.parent_idx = node.error_chain_idx;
428+
413429
next_detectors = detectors;
414430
next_detector_cost_tuples[ei].error_blocked = 1;
415431

@@ -453,7 +469,7 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
453469

454470
if (next_cost == INF) continue;
455471

456-
pq.push({next_cost, next_num_dets, next_errors});
472+
pq.push({next_cost, next_num_dets, node.depth + 1, (int64_t)(error_chain_arena.size() - 1)});
457473
++num_pq_pushed;
458474

459475
if (num_pq_pushed > config.pqlimit) {

src/tesseract.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ class Node {
5252
double cost;
5353
// The number of activated detectors (dets for short) at this node
5454
size_t num_dets;
55-
std::vector<size_t> errors;
55+
size_t depth;
56+
int64_t error_chain_idx = -1;
5657

5758
bool operator>(const Node& other) const;
5859
std::string str();
@@ -111,10 +112,11 @@ struct TesseractDecoder {
111112
std::vector<std::vector<int>> edets;
112113
size_t num_errors;
113114
std::vector<ErrorCost> error_costs;
115+
std::vector<common::ErrorChainNode> error_chain_arena;
114116

115117
void initialize_structures(size_t num_detectors);
116118
double get_detcost(size_t d, const std::vector<DetectorCostTuple>& detector_cost_tuples) const;
117-
void flip_detectors_and_block_errors(size_t detector_order, const std::vector<size_t>& errors,
119+
void flip_detectors_and_block_errors(size_t detector_order, int64_t error_chain_idx,
118120
boost::dynamic_bitset<>& detectors,
119121
std::vector<DetectorCostTuple>& detector_cost_tuples) const;
120122
};

src/visualization.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,15 @@ void Visualizer::add_detector_coords(const std::vector<std::vector<double>>& det
2020
}
2121
}
2222

23-
void Visualizer::add_activated_errors(const std::vector<size_t>& activated_errors) {
23+
void Visualizer::add_activated_errors(int64_t node_idx,
24+
const std::vector<common::ErrorChainNode>& arena) {
2425
std::stringstream ss;
2526
ss << "activated_errors = ";
26-
for (size_t oei : activated_errors) {
27-
ss << oei << ", ";
27+
int64_t walker_idx = node_idx;
28+
while (walker_idx != -1) {
29+
const auto& node = arena[walker_idx];
30+
ss << node.error_index << ", ";
31+
walker_idx = node.parent_idx;
2832
}
2933
lines.push_back(ss.str());
3034
}

src/visualization.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
struct Visualizer {
1111
void add_detector_coords(const std::vector<std::vector<double>>&);
1212
void add_errors(const std::vector<common::Error>&);
13-
void add_activated_errors(const std::vector<size_t>&);
13+
void add_activated_errors(int64_t node_idx, const std::vector<common::ErrorChainNode>& arena);
1414
void add_activated_detectors(const boost::dynamic_bitset<>&, size_t);
1515

1616
void write(const char* fpath);

0 commit comments

Comments
 (0)