Skip to content

Commit 2625f6e

Browse files
authored
Merge branch 'main' into optimization
2 parents 6d815f8 + 515625c commit 2625f6e

7 files changed

Lines changed: 532 additions & 58 deletions

File tree

src/common.pybind.h

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,41 @@
11
#ifndef TESSERACT_COMMON_PY_H
22
#define TESSERACT_COMMON_PY_H
33

4-
#include <vector>
5-
4+
#include <pybind11/operators.h>
65
#include <pybind11/pybind11.h>
76
#include <pybind11/stl.h>
8-
#include <pybind11/operators.h>
7+
8+
#include <vector>
99

1010
#include "common.h"
1111

1212
namespace py = pybind11;
1313

14-
void add_common_module(py::module &root)
15-
{
16-
auto m = root.def_submodule("common", "classes commonly used by the decoder");
17-
18-
// TODO: add as_dem_instruction_targets
19-
py::class_<common::Symptom>(m, "Symptom")
20-
.def(py::init<std::vector<int>, common::ObservablesMask>(), py::arg("detectors") = std::vector<int>(), py::arg("observables") = 0)
21-
.def_readwrite("detectors", &common::Symptom::detectors)
22-
.def_readwrite("observables", &common::Symptom::observables)
23-
.def("__str__", &common::Symptom::str)
24-
.def(py::self == py::self)
25-
.def(py::self != py::self);
26-
27-
// TODO: add constructor with stim::DemInstruction.
28-
py::class_<common::Error>(m, "Error")
29-
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
30-
.def_readwrite("probability", &common::Error::probability)
31-
.def_readwrite("symptom", &common::Error::symptom)
32-
.def("__str__", &common::Error::str)
33-
.def(py::init<>())
34-
.def(py::init<double, std::vector<int> &, common::ObservablesMask, std::vector<bool> &>())
35-
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask, std::vector<bool> &>());
14+
void add_common_module(py::module &root) {
15+
auto m = root.def_submodule("common", "classes commonly used by the decoder");
16+
17+
// TODO: add as_dem_instruction_targets
18+
py::class_<common::Symptom>(m, "Symptom")
19+
.def(py::init<std::vector<int>, common::ObservablesMask>(),
20+
py::arg("detectors") = std::vector<int>(),
21+
py::arg("observables") = 0)
22+
.def_readwrite("detectors", &common::Symptom::detectors)
23+
.def_readwrite("observables", &common::Symptom::observables)
24+
.def("__str__", &common::Symptom::str)
25+
.def(py::self == py::self)
26+
.def(py::self != py::self);
27+
28+
// TODO: add constructor with stim::DemInstruction.
29+
py::class_<common::Error>(m, "Error")
30+
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
31+
.def_readwrite("probability", &common::Error::probability)
32+
.def_readwrite("symptom", &common::Error::symptom)
33+
.def("__str__", &common::Error::str)
34+
.def(py::init<>())
35+
.def(py::init<double, std::vector<int> &, common::ObservablesMask,
36+
std::vector<bool> &>())
37+
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask,
38+
std::vector<bool> &>());
3639
}
3740

3841
#endif

src/tesseract.cc

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) {
4747
}
4848
assert(config.det_orders.size());
4949
errors = get_errors_from_dem(config.dem.flattened());
50+
if (config.verbose) {
51+
for (auto& error : errors) {
52+
std::cout << error.str() << std::endl;
53+
}
54+
}
5055
num_detectors = config.dem.count_detectors();
5156
num_errors = config.dem.count_errors();
5257
initialize_structures(config.dem.count_detectors());
@@ -86,21 +91,24 @@ void TesseractDecoder::initialize_structures(size_t num_detectors) {
8691

8792
struct VectorCharHash {
8893
size_t operator()(const std::vector<char>& v) const {
89-
size_t seed = v.size(); // Still good practice to incorporate vector size
94+
size_t seed = v.size(); // Still good practice to incorporate vector size
9095

91-
// Iterate over char elements. Accessing 'b_val' is now a direct memory read.
96+
// Iterate over char elements. Accessing 'b_val' is now a direct memory
97+
// read.
9298
for (char b_val : v) {
9399
// The polynomial rolling hash with 31 (or another prime)
94100
// 'b_val' is already a char (an 8-bit integer).
95-
// static_cast<size_t>(b_val) ensures it's promoted to size_t before arithmetic.
96-
// This cast is efficient (likely a simple register extension/move).
101+
// static_cast<size_t>(b_val) ensures it's promoted to size_t before
102+
// arithmetic. This cast is efficient (likely a simple register
103+
// extension/move).
97104
seed = seed * 31 + static_cast<size_t>(b_val);
98105
}
99106
return seed;
100107
}
101108
};
102109

103-
void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
110+
void TesseractDecoder::decode_to_errors(
111+
const std::vector<uint64_t>& detections) {
104112
std::vector<size_t> best_errors;
105113
double best_cost = std::numeric_limits<double>::max();
106114
assert(config.det_orders.size());
@@ -246,6 +254,18 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
246254

247255
if (node.num_dets == 0) {
248256
if (config.verbose) {
257+
std::cout << "activated_errors = ";
258+
for (size_t oei : node.errs) {
259+
std::cout << oei << ", ";
260+
}
261+
std::cout << std::endl;
262+
std::cout << "activated_dets = ";
263+
for (size_t d = 0; d < num_detectors; ++d) {
264+
if (node.dets[d]) {
265+
std::cout << d << ", ";
266+
}
267+
}
268+
std::cout << std::endl;
249269
std::cout.precision(13);
250270
std::cout << "Decoding complete. Cost: " << node.cost
251271
<< " num_pq_pushed = " << num_pq_pushed << std::endl;
@@ -269,10 +289,18 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
269289
std::cout << "num_dets = " << node.num_dets
270290
<< " max_num_dets = " << max_num_dets << " cost = " << node.cost
271291
<< std::endl;
292+
std::cout << "activated_errors = ";
272293
for (size_t oei : node.errs) {
273294
std::cout << oei << ", ";
274295
}
275296
std::cout << std::endl;
297+
std::cout << "activated_dets = ";
298+
for (size_t d = 0; d < num_detectors; ++d) {
299+
if (node.dets[d]) {
300+
std::cout << d << ", ";
301+
}
302+
}
303+
std::cout << std::endl;
276304
}
277305

278306
if (node.num_dets < min_num_dets) {

src/tesseract.pybind.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
#include <pybind11/pybind11.h>
2-
#include "pybind11/detail/common.h"
32

43
#include "common.pybind.h"
4+
#include "pybind11/detail/common.h"
55

6-
PYBIND11_MODULE(tesseract_py, m)
7-
{
8-
add_common_module(m);
9-
}
6+
PYBIND11_MODULE(tesseract_py, m) { add_common_module(m); }

src/tesseract_main.cc

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -179,34 +179,58 @@ struct Args {
179179

180180
std::vector<std::vector<double>> detector_coords =
181181
get_detector_coords(config.dem);
182+
if (verbose) {
183+
for (size_t d = 0; d < detector_coords.size(); ++d) {
184+
std::cout << "Detector D" << d << " coordinate (";
185+
size_t e = std::min(3ul, detector_coords[d].size());
186+
for (size_t i = 0; i < e; ++i) {
187+
std::cout << detector_coords[d][i];
188+
if (i + 1 < e) std::cout << ", ";
189+
}
190+
std::cout << ")" << std::endl;
191+
}
192+
}
182193

183194
std::vector<double> inner_products(config.dem.count_detectors());
184195

185-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
186-
// Sample a direction
187-
std::vector<double> orientation_vector;
188-
for (size_t i = 0; i < detector_coords.at(0).size(); ++i) {
189-
orientation_vector.push_back(dist(rng));
196+
if (!detector_coords.size() or !detector_coords.at(0).size()) {
197+
// If there are no detector coordinates, just use the standard ordering
198+
// of the indices.
199+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
200+
config.det_orders.emplace_back();
201+
std::iota(config.det_orders.back().begin(),
202+
config.det_orders.front().end(), 0);
190203
}
204+
} else {
205+
// Use the coordinates to order the detectors based on a random
206+
// orientation
207+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
208+
// Sample a direction
209+
std::vector<double> orientation_vector;
210+
for (size_t i = 0; i < detector_coords.at(0).size(); ++i) {
211+
orientation_vector.push_back(dist(rng));
212+
}
191213

192-
for (size_t i = 0; i < detector_coords.size(); ++i) {
193-
inner_products[i] = 0;
194-
for (size_t j = 0; j < orientation_vector.size(); ++j) {
195-
inner_products[i] += detector_coords[i][j] * orientation_vector[j];
214+
for (size_t i = 0; i < detector_coords.size(); ++i) {
215+
inner_products[i] = 0;
216+
for (size_t j = 0; j < orientation_vector.size(); ++j) {
217+
inner_products[i] +=
218+
detector_coords[i][j] * orientation_vector[j];
219+
}
196220
}
221+
std::vector<size_t> perm(config.dem.count_detectors());
222+
std::iota(perm.begin(), perm.end(), 0);
223+
std::sort(perm.begin(), perm.end(),
224+
[&](const size_t& i, const size_t& j) {
225+
return inner_products[i] > inner_products[j];
226+
});
227+
// Invert the permutation
228+
std::vector<size_t> inv_perm(config.dem.count_detectors());
229+
for (size_t i = 0; i < perm.size(); ++i) {
230+
inv_perm[perm[i]] = i;
231+
}
232+
config.det_orders[det_order] = inv_perm;
197233
}
198-
std::vector<size_t> perm(config.dem.count_detectors());
199-
std::iota(perm.begin(), perm.end(), 0);
200-
std::sort(perm.begin(), perm.end(),
201-
[&](const size_t& i, const size_t& j) {
202-
return inner_products[i] > inner_products[j];
203-
});
204-
// Invert the permutation
205-
std::vector<size_t> inv_perm(config.dem.count_detectors());
206-
for (size_t i = 0; i < perm.size(); ++i) {
207-
inv_perm[perm[i]] = i;
208-
}
209-
config.det_orders[det_order] = inv_perm;
210234
}
211235
}
212236

src/utils.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ std::vector<std::vector<double>> get_detector_coords(
3838
}
3939
case stim::DemInstructionType::DEM_DETECTOR: {
4040
std::vector<double> coord;
41-
for (const stim::DemTarget& t : instruction.target_data) {
42-
coord.push_back(t.val());
41+
for (const double& t : instruction.arg_data) {
42+
coord.push_back(t);
4343
}
4444
detector_coords.push_back(coord);
4545
break;

0 commit comments

Comments
 (0)