|
| 1 | +#ifndef MULTI_PASS_SINTER_COMPAT_PYBIND_H |
| 2 | +#define MULTI_PASS_SINTER_COMPAT_PYBIND_H |
| 3 | + |
| 4 | +#include <pybind11/iostream.h> |
| 5 | +#include <pybind11/numpy.h> |
| 6 | +#include <pybind11/pybind11.h> |
| 7 | +#include <pybind11/stl.h> |
| 8 | +#include <pybind11/functional.h> |
| 9 | +#include <iostream> |
| 10 | + |
| 11 | +#include "multi_pass_tesseract_decoder.h" |
| 12 | +#include "dem_decomposition.h" |
| 13 | +#include "utils.h" |
| 14 | + |
| 15 | +namespace py = pybind11; |
| 16 | + |
| 17 | +namespace tesseract { |
| 18 | + |
| 19 | +struct MultiPassSinterCompiledDecoder { |
| 20 | + std::unique_ptr<tesseract::MultiPassTesseractDecoder> decoder; |
| 21 | + uint64_t num_detectors; |
| 22 | + uint64_t num_observables; |
| 23 | + |
| 24 | + MultiPassSinterCompiledDecoder(std::unique_ptr<tesseract::MultiPassTesseractDecoder> d, uint64_t nd, uint64_t no) |
| 25 | + : decoder(std::move(d)), num_detectors(nd), num_observables(no) {} |
| 26 | + |
| 27 | + size_t num_components() const { return decoder->num_components(); } |
| 28 | + |
| 29 | + py::array_t<uint8_t> decode_shots_bit_packed(const py::array_t<uint8_t>& bit_packed_detection_event_data) { |
| 30 | + if (bit_packed_detection_event_data.ndim() != 2) throw std::invalid_argument("Input must be 2D."); |
| 31 | + const uint64_t num_detector_bytes = (num_detectors + 7) / 8; |
| 32 | + if (bit_packed_detection_event_data.shape(1) != (py::ssize_t)num_detector_bytes) throw std::invalid_argument("Wrong shape."); |
| 33 | + |
| 34 | + const size_t num_shots = bit_packed_detection_event_data.shape(0); |
| 35 | + const uint64_t num_observable_bytes = (num_observables + 7) / 8; |
| 36 | + |
| 37 | + auto result_array = py::array_t<uint8_t>({(py::ssize_t)num_shots, (py::ssize_t)num_observable_bytes}); |
| 38 | + auto result_buffer = result_array.mutable_data(); |
| 39 | + |
| 40 | + const uint8_t* detections_data = bit_packed_detection_event_data.data(); |
| 41 | + const size_t detections_stride = bit_packed_detection_event_data.strides(0); |
| 42 | + |
| 43 | + for (size_t shot = 0; shot < num_shots; ++shot) { |
| 44 | + const uint8_t* single_shot_data = detections_data + shot * detections_stride; |
| 45 | + std::vector<uint64_t> detections; |
| 46 | + for (uint64_t i = 0; i < num_detectors; ++i) { |
| 47 | + if ((single_shot_data[i / 8] >> (i % 8)) & 1) detections.push_back(i); |
| 48 | + } |
| 49 | + |
| 50 | + std::vector<int> predictions = decoder->decode(detections); |
| 51 | + uint8_t* single_result_buffer = result_buffer + shot * num_observable_bytes; |
| 52 | + std::fill(single_result_buffer, single_result_buffer + num_observable_bytes, 0); |
| 53 | + for (int obs_index : predictions) { |
| 54 | + if (obs_index >= 0 && (uint64_t)obs_index < num_observables) { |
| 55 | + single_result_buffer[obs_index / 8] ^= (1 << (obs_index % 8)); |
| 56 | + } |
| 57 | + } |
| 58 | + } |
| 59 | + return result_array; |
| 60 | + } |
| 61 | +}; |
| 62 | + |
| 63 | +struct MultiPassSinterDecoder { |
| 64 | + size_t num_passes; |
| 65 | + py::object full_decomposer; |
| 66 | + py::object detector_classifier; |
| 67 | + TesseractConfig base_config; |
| 68 | + size_t num_det_orders; |
| 69 | + ::DetOrder det_order_method; |
| 70 | + uint64_t seed; |
| 71 | + SchedulingStrategy strategy; |
| 72 | + |
| 73 | + MultiPassSinterDecoder(size_t n=2) : num_passes(n), full_decomposer(py::none()), detector_classifier(py::none()), num_det_orders(1), det_order_method(::DetOrder::DetBFS), seed(0), strategy(SchedulingStrategy::Static) {} |
| 74 | + |
| 75 | + MultiPassSinterCompiledDecoder compile_decoder_for_dem(const py::object& dem) { |
| 76 | + stim::DetectorErrorModel stim_dem; |
| 77 | + |
| 78 | + if (!full_decomposer.is_none()) { |
| 79 | + py::gil_scoped_acquire acquire; |
| 80 | + py::object decomposed_py_dem = full_decomposer(dem); |
| 81 | + stim_dem = stim::DetectorErrorModel(py::cast<std::string>(py::str(decomposed_py_dem)).c_str()); |
| 82 | + } else { |
| 83 | + stim_dem = stim::DetectorErrorModel(py::cast<std::string>(py::str(dem)).c_str()); |
| 84 | + } |
| 85 | + |
| 86 | + std::vector<int> classification; |
| 87 | + if (py::isinstance<py::function>(detector_classifier)) { |
| 88 | + uint64_t num_dets = stim_dem.count_detectors(); |
| 89 | + |
| 90 | + std::set<uint64_t> detector_ids; |
| 91 | + std::map<uint64_t, std::string> tags; |
| 92 | + for (const auto& inst : stim_dem.flattened().instructions) { |
| 93 | + if (inst.type == stim::DemInstructionType::DEM_DETECTOR) { |
| 94 | + uint64_t d = inst.target_data[0].val(); |
| 95 | + detector_ids.insert(d); |
| 96 | + tags[d] = inst.tag; |
| 97 | + } |
| 98 | + } |
| 99 | + auto coords_map = stim_dem.get_detector_coordinates(detector_ids); |
| 100 | + |
| 101 | + for (uint64_t i = 0; i < num_dets; ++i) { |
| 102 | + std::vector<double> c = coords_map.count(i) ? coords_map.at(i) : std::vector<double>{}; |
| 103 | + std::string t = tags.count(i) ? tags.at(i) : ""; |
| 104 | + py::gil_scoped_acquire acquire; |
| 105 | + classification.push_back(py::cast<int>(detector_classifier((int)i, c, t))); |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + tesseract::DetectorClassifier classifier = [classification](int index, const std::vector<double>& coords, const std::string& tag) -> int { |
| 110 | + if (index >= 0 && (size_t)index < classification.size()) return classification[index]; |
| 111 | + return 0; |
| 112 | + }; |
| 113 | + |
| 114 | + auto decoder = std::make_unique<tesseract::MultiPassTesseractDecoder>(stim_dem, num_passes, classifier, base_config, num_det_orders, det_order_method, seed, strategy); |
| 115 | + |
| 116 | + return MultiPassSinterCompiledDecoder(std::move(decoder), stim_dem.count_detectors(), stim_dem.count_observables()); |
| 117 | + } |
| 118 | +}; |
| 119 | + |
| 120 | +void pybind_multi_pass_sinter_compat(py::module& m) { |
| 121 | + py::enum_<SchedulingStrategy>(m, "SchedulingStrategy") |
| 122 | + .value("Static", SchedulingStrategy::Static) |
| 123 | + .value("Causal", SchedulingStrategy::Causal) |
| 124 | + .export_values(); |
| 125 | + |
| 126 | + py::class_<MultiPassSinterCompiledDecoder>(m, "MultiPassSinterCompiledDecoder") |
| 127 | + .def_property_readonly("num_components", &MultiPassSinterCompiledDecoder::num_components) |
| 128 | + .def("decode_shots_bit_packed", &MultiPassSinterCompiledDecoder::decode_shots_bit_packed, |
| 129 | + py::kw_only(), py::arg("bit_packed_detection_event_data")); |
| 130 | + |
| 131 | + py::class_<MultiPassSinterDecoder>(m, "MultiPassSinterDecoder") |
| 132 | + .def(py::init<size_t>(), py::arg("num_passes") = 2) |
| 133 | + .def_readwrite("full_decomposer", &MultiPassSinterDecoder::full_decomposer) |
| 134 | + .def_readwrite("detector_classifier", &MultiPassSinterDecoder::detector_classifier) |
| 135 | + .def_readwrite("base_config", &MultiPassSinterDecoder::base_config) |
| 136 | + .def_readwrite("num_det_orders", &MultiPassSinterDecoder::num_det_orders) |
| 137 | + .def_readwrite("det_order_method", &MultiPassSinterDecoder::det_order_method) |
| 138 | + .def_readwrite("seed", &MultiPassSinterDecoder::seed) |
| 139 | + .def_readwrite("strategy", &MultiPassSinterDecoder::strategy) |
| 140 | + .def("compile_decoder_for_dem", &MultiPassSinterDecoder::compile_decoder_for_dem, |
| 141 | + py::kw_only(), py::arg("dem")); |
| 142 | +} |
| 143 | + |
| 144 | +} // namespace tesseract |
| 145 | + |
| 146 | +#endif // MULTI_PASS_SINTER_COMPAT_PYBIND_H |
0 commit comments