Skip to content

Commit a72f33b

Browse files
committed
Add multi-pass Tesseract decoder
Add the multi-pass Tesseract decoder, which decomposes a detector error model into independent components by detector class and decodes each component separately across multiple passes. Between passes, first-pass decoding correlations are used to reweight error probabilities in subsequent components, improving accuracy. Key components: - MultiPassTesseractDecoder: core decoder with static and causal scheduling across detector classes - FastTwoPassTesseractDecoder: optimised two-pass specialisation - multi_pass_sinter_compat.pybind.h: pybind11 bindings exposing MultiPassSinterDecoder and MultiPassSinterCompiledDecoder - Python integration tests for multi-pass bindings - Theory and architecture documentation Performance: 10-100x wall-clock speedup over single-pass Tesseract by decomposing the DEM into smaller independent components.
1 parent b8f4393 commit a72f33b

10 files changed

Lines changed: 924 additions & 0 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,6 @@ user.bazelrc
3939
src/tesseract_decoder*.so
4040

4141
MODULE.bazel.lock
42+
build/
43+
_core.so
44+
*.egg-info/

BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,8 @@ config_setting(
5959
"@platforms//cpu:x86_64",
6060
],
6161
)
62+
filegroup(
63+
name = "testdata",
64+
srcs = glob(["testdata/**/*"]),
65+
visibility = ["//visibility:public"],
66+
)

CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ target_include_directories(dem_decomposition PUBLIC ${TESSERACT_SRC_DIR})
113113
target_compile_options(dem_decomposition PRIVATE ${OPT_COPTS})
114114
target_link_libraries(dem_decomposition PUBLIC bern_utils libstim)
115115

116+
add_library(multi_pass_tesseract_decoder ${TESSERACT_SRC_DIR}/multi_pass_tesseract_decoder.cc ${TESSERACT_SRC_DIR}/multi_pass_tesseract_decoder.h)
117+
target_include_directories(multi_pass_tesseract_decoder PUBLIC ${TESSERACT_SRC_DIR})
118+
target_compile_options(multi_pass_tesseract_decoder PRIVATE ${OPT_COPTS})
119+
target_link_libraries(multi_pass_tesseract_decoder PUBLIC tesseract_lib tanner_graph error_correlations dem_decomposition libstim)
120+
116121
add_library(tesseract_lib ${TESSERACT_SRC_DIR}/tesseract.cc ${TESSERACT_SRC_DIR}/tesseract.h)
117122
target_include_directories(tesseract_lib PUBLIC ${TESSERACT_SRC_DIR})
118123
target_compile_options(tesseract_lib PRIVATE ${OPT_COPTS})
@@ -137,6 +142,7 @@ target_link_libraries(simplex_bin PRIVATE common simplex argparse::argparse nloh
137142
pybind11_add_module(_core MODULE ${TESSERACT_SRC_DIR}/tesseract.pybind.cc)
138143
target_compile_options(_core PRIVATE ${OPT_COPTS})
139144
target_include_directories(_core PRIVATE ${TESSERACT_SRC_DIR})
145+
target_link_libraries(_core PRIVATE common utils simplex tesseract_lib multi_pass_tesseract_decoder)
140146
set_target_properties(_core PROPERTIES
141147
LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/tesseract_decoder
142148
LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/tesseract_decoder
@@ -168,3 +174,7 @@ add_executable(error_correlations_test ${TESSERACT_SRC_DIR}/error_correlations.t
168174
target_link_libraries(error_correlations_test PRIVATE error_correlations GTest::gtest_main libstim)
169175
add_test(NAME error_correlations_test COMMAND error_correlations_test)
170176

177+
add_executable(multi_pass_tesseract_decoder_test ${TESSERACT_SRC_DIR}/multi_pass_tesseract_decoder.test.cc)
178+
target_link_libraries(multi_pass_tesseract_decoder_test PRIVATE multi_pass_tesseract_decoder GTest::gtest_main libstim)
179+
add_test(NAME multi_pass_tesseract_decoder_test COMMAND multi_pass_tesseract_decoder_test)
180+

src/BUILD

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,15 @@ pybind_library(
7171
"visualization.pybind.h",
7272
"tesseract.pybind.h",
7373
"tesseract_sinter_compat.pybind.h",
74+
"multi_pass_sinter_compat.pybind.h",
7475
],
7576
copts = OPT_COPTS,
7677
deps = [
7778
":libcommon",
7879
":libutils",
7980
":libsimplex",
8081
":libtesseract",
82+
":libmulti_pass_tesseract_decoder",
8183
],
8284
)
8385

@@ -132,6 +134,35 @@ cc_library(
132134
],
133135
)
134136

137+
cc_library(
138+
name = "libmulti_pass_tesseract_decoder",
139+
srcs = ["multi_pass_tesseract_decoder.cc"],
140+
hdrs = ["multi_pass_tesseract_decoder.h"],
141+
copts = OPT_COPTS,
142+
linkopts = OPT_LINKOPTS,
143+
deps = [
144+
":libtesseract",
145+
":libtanner_graph",
146+
":liberror_correlations",
147+
":libdem_decomposition",
148+
"@stim//:stim_lib",
149+
],
150+
)
151+
152+
cc_test(
153+
name = "multi_pass_tesseract_decoder_tests",
154+
srcs = ["multi_pass_tesseract_decoder.test.cc"],
155+
copts = OPT_COPTS,
156+
linkopts = OPT_LINKOPTS,
157+
data = ["//:testdata"],
158+
deps = [
159+
":libmulti_pass_tesseract_decoder",
160+
"@gtest",
161+
"@gtest//:gtest_main",
162+
"@stim//:stim_lib",
163+
],
164+
)
165+
135166
cc_library(
136167
name = "liberror_correlations",
137168
srcs = ["error_correlations.cc"],
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
#ifndef MULTI_PASS_SINTER_COMPAT_PYBIND_H
2+
#define MULTI_PASS_SINTER_COMPAT_PYBIND_H
3+
4+
#include <pybind11/iostream.h>
5+
#include <pybind11/numpy.h>
6+
#include <pybind11/pybind11.h>
7+
#include <pybind11/stl.h>
8+
#include <pybind11/functional.h>
9+
#include <iostream>
10+
11+
#include "multi_pass_tesseract_decoder.h"
12+
#include "dem_decomposition.h"
13+
#include "utils.h"
14+
15+
namespace py = pybind11;
16+
17+
namespace tesseract {
18+
19+
struct MultiPassSinterCompiledDecoder {
20+
std::unique_ptr<tesseract::MultiPassTesseractDecoder> decoder;
21+
uint64_t num_detectors;
22+
uint64_t num_observables;
23+
24+
MultiPassSinterCompiledDecoder(std::unique_ptr<tesseract::MultiPassTesseractDecoder> d, uint64_t nd, uint64_t no)
25+
: decoder(std::move(d)), num_detectors(nd), num_observables(no) {}
26+
27+
size_t num_components() const { return decoder->num_components(); }
28+
29+
py::array_t<uint8_t> decode_shots_bit_packed(const py::array_t<uint8_t>& bit_packed_detection_event_data) {
30+
if (bit_packed_detection_event_data.ndim() != 2) throw std::invalid_argument("Input must be 2D.");
31+
const uint64_t num_detector_bytes = (num_detectors + 7) / 8;
32+
if (bit_packed_detection_event_data.shape(1) != (py::ssize_t)num_detector_bytes) throw std::invalid_argument("Wrong shape.");
33+
34+
const size_t num_shots = bit_packed_detection_event_data.shape(0);
35+
const uint64_t num_observable_bytes = (num_observables + 7) / 8;
36+
37+
auto result_array = py::array_t<uint8_t>({(py::ssize_t)num_shots, (py::ssize_t)num_observable_bytes});
38+
auto result_buffer = result_array.mutable_data();
39+
40+
const uint8_t* detections_data = bit_packed_detection_event_data.data();
41+
const size_t detections_stride = bit_packed_detection_event_data.strides(0);
42+
43+
for (size_t shot = 0; shot < num_shots; ++shot) {
44+
const uint8_t* single_shot_data = detections_data + shot * detections_stride;
45+
std::vector<uint64_t> detections;
46+
for (uint64_t i = 0; i < num_detectors; ++i) {
47+
if ((single_shot_data[i / 8] >> (i % 8)) & 1) detections.push_back(i);
48+
}
49+
50+
std::vector<int> predictions = decoder->decode(detections);
51+
uint8_t* single_result_buffer = result_buffer + shot * num_observable_bytes;
52+
std::fill(single_result_buffer, single_result_buffer + num_observable_bytes, 0);
53+
for (int obs_index : predictions) {
54+
if (obs_index >= 0 && (uint64_t)obs_index < num_observables) {
55+
single_result_buffer[obs_index / 8] ^= (1 << (obs_index % 8));
56+
}
57+
}
58+
}
59+
return result_array;
60+
}
61+
};
62+
63+
struct MultiPassSinterDecoder {
64+
size_t num_passes;
65+
py::object full_decomposer;
66+
py::object detector_classifier;
67+
TesseractConfig base_config;
68+
size_t num_det_orders;
69+
::DetOrder det_order_method;
70+
uint64_t seed;
71+
SchedulingStrategy strategy;
72+
73+
MultiPassSinterDecoder(size_t n=2) : num_passes(n), full_decomposer(py::none()), detector_classifier(py::none()), num_det_orders(1), det_order_method(::DetOrder::DetBFS), seed(0), strategy(SchedulingStrategy::Static) {}
74+
75+
MultiPassSinterCompiledDecoder compile_decoder_for_dem(const py::object& dem) {
76+
stim::DetectorErrorModel stim_dem;
77+
78+
if (!full_decomposer.is_none()) {
79+
py::gil_scoped_acquire acquire;
80+
py::object decomposed_py_dem = full_decomposer(dem);
81+
stim_dem = stim::DetectorErrorModel(py::cast<std::string>(py::str(decomposed_py_dem)).c_str());
82+
} else {
83+
stim_dem = stim::DetectorErrorModel(py::cast<std::string>(py::str(dem)).c_str());
84+
}
85+
86+
std::vector<int> classification;
87+
if (py::isinstance<py::function>(detector_classifier)) {
88+
uint64_t num_dets = stim_dem.count_detectors();
89+
90+
std::set<uint64_t> detector_ids;
91+
std::map<uint64_t, std::string> tags;
92+
for (const auto& inst : stim_dem.flattened().instructions) {
93+
if (inst.type == stim::DemInstructionType::DEM_DETECTOR) {
94+
uint64_t d = inst.target_data[0].val();
95+
detector_ids.insert(d);
96+
tags[d] = inst.tag;
97+
}
98+
}
99+
auto coords_map = stim_dem.get_detector_coordinates(detector_ids);
100+
101+
for (uint64_t i = 0; i < num_dets; ++i) {
102+
std::vector<double> c = coords_map.count(i) ? coords_map.at(i) : std::vector<double>{};
103+
std::string t = tags.count(i) ? tags.at(i) : "";
104+
py::gil_scoped_acquire acquire;
105+
classification.push_back(py::cast<int>(detector_classifier((int)i, c, t)));
106+
}
107+
}
108+
109+
tesseract::DetectorClassifier classifier = [classification](int index, const std::vector<double>& coords, const std::string& tag) -> int {
110+
if (index >= 0 && (size_t)index < classification.size()) return classification[index];
111+
return 0;
112+
};
113+
114+
auto decoder = std::make_unique<tesseract::MultiPassTesseractDecoder>(stim_dem, num_passes, classifier, base_config, num_det_orders, det_order_method, seed, strategy);
115+
116+
return MultiPassSinterCompiledDecoder(std::move(decoder), stim_dem.count_detectors(), stim_dem.count_observables());
117+
}
118+
};
119+
120+
void pybind_multi_pass_sinter_compat(py::module& m) {
121+
py::enum_<SchedulingStrategy>(m, "SchedulingStrategy")
122+
.value("Static", SchedulingStrategy::Static)
123+
.value("Causal", SchedulingStrategy::Causal)
124+
.export_values();
125+
126+
py::class_<MultiPassSinterCompiledDecoder>(m, "MultiPassSinterCompiledDecoder")
127+
.def_property_readonly("num_components", &MultiPassSinterCompiledDecoder::num_components)
128+
.def("decode_shots_bit_packed", &MultiPassSinterCompiledDecoder::decode_shots_bit_packed,
129+
py::kw_only(), py::arg("bit_packed_detection_event_data"));
130+
131+
py::class_<MultiPassSinterDecoder>(m, "MultiPassSinterDecoder")
132+
.def(py::init<size_t>(), py::arg("num_passes") = 2)
133+
.def_readwrite("full_decomposer", &MultiPassSinterDecoder::full_decomposer)
134+
.def_readwrite("detector_classifier", &MultiPassSinterDecoder::detector_classifier)
135+
.def_readwrite("base_config", &MultiPassSinterDecoder::base_config)
136+
.def_readwrite("num_det_orders", &MultiPassSinterDecoder::num_det_orders)
137+
.def_readwrite("det_order_method", &MultiPassSinterDecoder::det_order_method)
138+
.def_readwrite("seed", &MultiPassSinterDecoder::seed)
139+
.def_readwrite("strategy", &MultiPassSinterDecoder::strategy)
140+
.def("compile_decoder_for_dem", &MultiPassSinterDecoder::compile_decoder_for_dem,
141+
py::kw_only(), py::arg("dem"));
142+
}
143+
144+
} // namespace tesseract
145+
146+
#endif // MULTI_PASS_SINTER_COMPAT_PYBIND_H

0 commit comments

Comments
 (0)