diff --git a/src/BUILD b/src/BUILD index ecbc22ad..a222eaf6 100644 --- a/src/BUILD +++ b/src/BUILD @@ -68,9 +68,15 @@ pybind_library( name = "tesseract_decoder_pybind", srcs = [ "common.pybind.h", + "utils.pybind.h", + "simplex.pybind.h", + "tesseract.pybind.h", ], deps = [ ":libcommon", + ":libutils", + ":libsimplex", + ":libtesseract", "@stim_py//:stim_pybind_lib", ], ) diff --git a/src/common.pybind.h b/src/common.pybind.h index 791a9ed0..04c4e404 100644 --- a/src/common.pybind.h +++ b/src/common.pybind.h @@ -1,54 +1,75 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef TESSERACT_COMMON_PY_H #define TESSERACT_COMMON_PY_H -#include - #include #include #include -#include "src/stim/dem/dem_instruction.pybind.h" -#include "stim/dem/detector_error_model_target.pybind.h" +#include #include "common.h" +#include "src/stim/dem/dem_instruction.pybind.h" +#include "stim/dem/detector_error_model_target.pybind.h" namespace py = pybind11; -void add_common_module(py::module &root) -{ - auto m = root.def_submodule("common", "classes commonly used by the decoder"); - - py::class_(m, "Symptom") - .def(py::init, common::ObservablesMask>(), - py::arg("detectors") = std::vector(), - py::arg("observables") = 0) - .def_readwrite("detectors", &common::Symptom::detectors) - .def_readwrite("observables", &common::Symptom::observables) - .def("__str__", &common::Symptom::str) - .def(py::self == py::self) - .def(py::self != py::self) - .def("as_dem_instruction_targets", [](common::Symptom s) - { - std::vector ret; - for(auto & t : s.as_dem_instruction_targets()) ret.emplace_back(t); - return ret; }); - - py::class_(m, "Error") - .def_readwrite("likelihood_cost", &common::Error::likelihood_cost) - .def_readwrite("probability", &common::Error::probability) - .def_readwrite("symptom", &common::Error::symptom) - .def("__str__", &common::Error::str) - .def(py::init<>()) - .def(py::init &, common::ObservablesMask, - std::vector &>()) - .def(py::init &, common::ObservablesMask, - std::vector &>()) - .def(py::init([](stim_pybind::ExposedDemInstruction edi) - { return new common::Error(edi.as_dem_instruction()); })); - - m.def("merge_identical_errors", &common::merge_identical_errors); - m.def("remove_zero_probability_errors", &common::remove_zero_probability_errors); - m.def("dem_from_counts", &common::dem_from_counts); +void add_common_module(py::module &root) { + auto m = root.def_submodule("common", "classes commonly used by the decoder"); + + py::class_(m, "Symptom") + .def(py::init, common::ObservablesMask>(), + py::arg("detectors") = std::vector(), + py::arg("observables") = 0) + .def_readwrite("detectors", &common::Symptom::detectors) + .def_readwrite("observables", &common::Symptom::observables) + .def("__str__", &common::Symptom::str) + .def(py::self == py::self) + .def(py::self != py::self) + .def("as_dem_instruction_targets", [](common::Symptom s) { + std::vector ret; + for (auto &t : s.as_dem_instruction_targets()) ret.emplace_back(t); + return ret; + }); + + py::class_(m, "Error") + .def_readwrite("likelihood_cost", &common::Error::likelihood_cost) + .def_readwrite("probability", &common::Error::probability) + .def_readwrite("symptom", &common::Error::symptom) + .def("__str__", &common::Error::str) + .def(py::init<>()) + .def(py::init &, common::ObservablesMask, + std::vector &>(), + py::arg("likelihood_cost"), py::arg("detectors"), + py::arg("observables"), py::arg("dets_array")) + .def(py::init &, common::ObservablesMask, + std::vector &>(), + py::arg("likelihood_cost"), py::arg("probability"), + py::arg("detectors"), py::arg("observables"), py::arg("dets_array")) + .def(py::init([](stim_pybind::ExposedDemInstruction edi) { + return new common::Error(edi.as_dem_instruction()); + }), + py::arg("error")); + + m.def("merge_identical_errors", &common::merge_identical_errors, + py::arg("dem")); + m.def("remove_zero_probability_errors", + &common::remove_zero_probability_errors, py::arg("dem")); + m.def("dem_from_counts", &common::dem_from_counts, py::arg("orig_dem"), + py::arg("error_counts"), py::arg("num_shots")); } #endif diff --git a/src/py/BUILD b/src/py/BUILD index 240ad9be..9f24e4c5 100644 --- a/src/py/BUILD +++ b/src/py/BUILD @@ -11,6 +11,36 @@ py_test( ], ) +py_test( + name = "utils_test", + srcs = ["utils_test.py"], + visibility = ["//:__subpackages__"], + deps = [ + "@pypi//pytest", + "//src:lib_tesseract_decoder", + ], +) + +py_test( + name = "simplex_test", + srcs = ["simplex_test.py"], + visibility = ["//:__subpackages__"], + deps = [ + "@pypi//pytest", + "//src:lib_tesseract_decoder", + ], +) + +py_test( + name = "tesseract_test", + srcs = ["tesseract_test.py"], + visibility = ["//:__subpackages__"], + deps = [ + "@pypi//pytest", + "//src:lib_tesseract_decoder", + ], +) + compile_pip_requirements( name = "requirements", src = "requirements.in", diff --git a/src/py/common_test.py b/src/py/common_test.py index a39c3c2c..b059a412 100644 --- a/src/py/common_test.py +++ b/src/py/common_test.py @@ -1,3 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pytest import stim diff --git a/src/py/simplex_test.py b/src/py/simplex_test.py new file mode 100644 index 00000000..965fe9ea --- /dev/null +++ b/src/py/simplex_test.py @@ -0,0 +1,52 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import pytest +import stim + +from src import tesseract_decoder + +_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel( + """ +error(0.125) D0 +error(0.375) D0 D1 +error(0.25) D1 +""" +) + + +def test_create_simplex_config(): + sc = tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5) + assert sc.dem == _DETECTOR_ERROR_MODEL + assert sc.window_length == 5 + assert ( + str(sc) + == "SimplexConfig(dem=DetectorErrorModel_Object, window_length=5, window_slide_length=0, verbose=0)" + ) + + +def test_create_simplex_decoder(): + decoder = tesseract_decoder.simplex.SimplexDecoder( + tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5) + ) + decoder.init_ilp() + decoder.decode_to_errors([1]) + assert decoder.mask_from_errors([1]) == 0 + assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123) + assert decoder.decode([1, 2]) == 0 + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/src/py/tesseract_test.py b/src/py/tesseract_test.py new file mode 100644 index 00000000..3c06a3f8 --- /dev/null +++ b/src/py/tesseract_test.py @@ -0,0 +1,59 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import pytest +import stim + +from src import tesseract_decoder + +_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel( + """ +error(0.125) D0 +error(0.375) D0 D1 +error(0.25) D1 +""" +) + + +def test_create_config(): + assert ( + str(tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL)) + == "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)" + ) + + +def test_create_node(): + node = tesseract_decoder.tesseract.Node(dets=["a"]) + assert node.dets == ["a"] + + +def test_create_qnode(): + qnode = tesseract_decoder.tesseract.QNode(num_dets=5, errs=[42]) + assert qnode.num_dets == 5 + assert str(qnode) == "QNode(cost=0, num_dets=5, errs=[42])" + + +def test_create_decoder(): + config = tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL) + decoder = tesseract_decoder.tesseract.TesseractDecoder(config) + decoder.decode_to_errors([0]) + decoder.decode_to_errors([0], 0) + assert decoder.mask_from_errors([1]) == 0 + assert decoder.cost_from_errors([1]) == pytest.approx(1.609438) + assert decoder.decode([0]) == 0 + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/src/py/utils_test.py b/src/py/utils_test.py new file mode 100644 index 00000000..e9d9cfdd --- /dev/null +++ b/src/py/utils_test.py @@ -0,0 +1,58 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import pytest +import stim + +from src import tesseract_decoder + + +_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel( + """ +error(0.125) D0 +error(0.375) D0 D1 +error(0.25) D1 +""" +) + + +def test_module_has_global_constants(): + assert tesseract_decoder.utils.EPSILON <= 1e-7 + assert not math.isfinite(tesseract_decoder.utils.INF) + + +def test_get_detector_coords(): + assert tesseract_decoder.utils.get_detector_coords(_DETECTOR_ERROR_MODEL) == [] + + +def test_build_detector_graph(): + assert tesseract_decoder.utils.build_detector_graph(_DETECTOR_ERROR_MODEL) == [ + [1], + [0], + ] + + +def test_get_errors_from_dem(): + expected = "Error{cost=1.945910, symptom=Symptom{D0 }}, Error{cost=0.510826, symptom=Symptom{D0 D1 }}, Error{cost=1.098612, symptom=Symptom{D1 }}" + assert ( + ", ".join( + map(str, tesseract_decoder.utils.get_errors_from_dem(_DETECTOR_ERROR_MODEL)) + ) + == expected + ) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/src/simplex.cc b/src/simplex.cc index cbe0c941..dfadb51c 100644 --- a/src/simplex.cc +++ b/src/simplex.cc @@ -21,6 +21,17 @@ constexpr size_t T_COORD = 2; +std::string SimplexConfig::str() { + auto & self = *this; + std::stringstream ss; + ss << "SimplexConfig("; + ss << "dem=" << "DetectorErrorModel_Object" << ", "; + ss << "window_length=" << self.window_length << ", "; + ss << "window_slide_length=" << self.window_slide_length << ", "; + ss << "verbose=" << self.verbose << ")"; + return ss.str(); +} + SimplexDecoder::SimplexDecoder(SimplexConfig _config) : config(_config) { config.dem = common::remove_zero_probability_errors(config.dem); std::vector detector_t_coords(config.dem.count_detectors()); diff --git a/src/simplex.h b/src/simplex.h index 94d02d12..91402dc1 100644 --- a/src/simplex.h +++ b/src/simplex.h @@ -30,6 +30,7 @@ struct SimplexConfig { size_t window_slide_length = 0; bool verbose = false; bool windowing_enabled() { return (window_length != 0); } + std::string str(); }; struct SimplexDecoder { diff --git a/src/simplex.pybind.h b/src/simplex.pybind.h new file mode 100644 index 00000000..a69e9684 --- /dev/null +++ b/src/simplex.pybind.h @@ -0,0 +1,66 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _SIMPLEX_PYBIND_H +#define _SIMPLEX_PYBIND_H + +#include +#include +#include + +#include "common.h" +#include "simplex.h" + +namespace py = pybind11; + +void add_simplex_module(py::module &root) { + auto m = root.def_submodule( + "simplex", "Module containing the SimplexDecoder and related methods"); + + py::class_(m, "SimplexConfig") + .def(py::init(), + py::arg("dem"), py::arg("parallelize") = false, + py::arg("window_length") = 0, py::arg("window_slide_length") = 0, + py::arg("verbose") = false) + .def_readwrite("dem", &SimplexConfig::dem) + .def_readwrite("parallelize", &SimplexConfig::parallelize) + .def_readwrite("window_length", &SimplexConfig::window_length) + .def_readwrite("window_slide_length", &SimplexConfig::window_slide_length) + .def_readwrite("verbose", &SimplexConfig::verbose) + .def("windowing_enabled", &SimplexConfig::windowing_enabled) + .def("__str__", &SimplexConfig::str); + + py::class_(m, "SimplexDecoder") + .def(py::init(), py::arg("config")) + .def_readwrite("config", &SimplexDecoder::config) + .def_readwrite("errors", &SimplexDecoder::errors) + .def_readwrite("num_detectors", &SimplexDecoder::num_detectors) + .def_readwrite("num_observables", &SimplexDecoder::num_observables) + .def_readwrite("predicted_errors_buffer", + &SimplexDecoder::predicted_errors_buffer) + .def_readwrite("error_masks", &SimplexDecoder::error_masks) + .def_readwrite("start_time_to_errors", + &SimplexDecoder::start_time_to_errors) + .def_readwrite("end_time_to_errors", &SimplexDecoder::end_time_to_errors) + .def_readonly("low_confidence_flag", &SimplexDecoder::low_confidence_flag) + .def("init_ilp", &SimplexDecoder::init_ilp) + .def("decode_to_errors", &SimplexDecoder::decode_to_errors, + py::arg("detections")) + .def("mask_from_errors", &SimplexDecoder::mask_from_errors, + py::arg("predicted_errors")) + .def("cost_from_errors", &SimplexDecoder::cost_from_errors, + py::arg("predicted_errors")) + .def("decode", &SimplexDecoder::decode, py::arg("detections")); +} +#endif diff --git a/src/tesseract.cc b/src/tesseract.cc index 82acffb3..0c1099ce 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -18,10 +18,72 @@ #include #include +namespace +{ + + template + std::ostream &operator<<(std::ostream &os, const std::vector &vec) + { + os << "["; + bool is_first = true; + for (auto &x : vec) + { + if (!is_first) + { + os << ", "; + } + is_first = false; + os << x; + } + os << "]"; + return os; + } + +}; + +std::string TesseractConfig::str() +{ + auto &config = *this; + std::stringstream ss; + ss << "TesseractConfig("; + ss << "dem=DetectorErrorModel_Object" << ", "; + ss << "det_beam=" << config.det_beam << ", "; + ss << "no_revisit_dets=" << config.no_revisit_dets << ", "; + ss << "at_most_two_errors_per_detector=" << config.at_most_two_errors_per_detector << ", "; + ss << "verbose=" << config.verbose << ", "; + ss << "pqlimit=" << config.pqlimit << ", "; + ss << "det_orders=" << config.det_orders << ", "; + ss << "det_penalty=" << config.det_penalty << ")"; + return ss.str(); +} + bool Node::operator>(const Node& other) const { return cost > other.cost || (cost == other.cost && num_dets < other.num_dets); } +std::string Node::str() +{ + std::stringstream ss; + auto &self = *this; + ss << "Node("; + ss << "errs=" << self.errs << ", "; + ss << "dets=" << self.dets << ", "; + ss << "cost=" << self.cost << ", "; + ss << "num_dets=" << self.num_dets << ", "; + ss << "blocked_errs=" << self.blocked_errs << ")"; + return ss.str(); +} + +std::string QNode::str() { + auto & self = *this; + std::stringstream ss; + ss << "QNode("; + ss << "cost=" << self.cost << ", "; + ss << "num_dets=" << self.num_dets << ", "; + ss << "errs=" << self.errs << ")"; + return ss.str(); +} + double TesseractDecoder::get_detcost(size_t d, const std::vector& blocked_errs, const std::vector& det_counts) const { diff --git a/src/tesseract.h b/src/tesseract.h index 0997fdac..cd290eb5 100644 --- a/src/tesseract.h +++ b/src/tesseract.h @@ -37,6 +37,8 @@ struct TesseractConfig { size_t pqlimit = std::numeric_limits::max(); std::vector> det_orders; double det_penalty = 0; + + std::string str(); }; class Node { @@ -48,6 +50,7 @@ class Node { std::vector blocked_errs; bool operator>(const Node& other) const; + std::string str(); }; class QNode { @@ -57,6 +60,7 @@ class QNode { std::vector errs; bool operator>(const QNode& other) const; + std::string str(); }; struct TesseractDecoder { diff --git a/src/tesseract.pybind.cc b/src/tesseract.pybind.cc index e78f1c56..4c35029d 100644 --- a/src/tesseract.pybind.cc +++ b/src/tesseract.pybind.cc @@ -1,10 +1,30 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tesseract.pybind.h" + #include #include "common.pybind.h" #include "pybind11/detail/common.h" +#include "simplex.pybind.h" +#include "utils.pybind.h" -PYBIND11_MODULE(tesseract_decoder, m) -{ - py::module::import("stim"); - add_common_module(m); +PYBIND11_MODULE(tesseract_decoder, tesseract) { + py::module::import("stim"); + add_common_module(tesseract); + add_utils_module(tesseract); + add_simplex_module(tesseract); + add_tesseract_module(tesseract); } diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h new file mode 100644 index 00000000..5c57a2e9 --- /dev/null +++ b/src/tesseract.pybind.h @@ -0,0 +1,100 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _TESSERACT_PYBIND_H +#define _TESSERACT_PYBIND_H + +#include +#include +#include + +#include "tesseract.h" + +namespace py = pybind11; + +void add_tesseract_module(py::module &root) { + auto m = root.def_submodule("tesseract", + "Module containing the tesseract algorithm"); + + m.attr("INF_DET_BEAM") = INF_DET_BEAM; + py::class_(m, "TesseractConfig") + .def(py::init>, double>(), + py::arg("dem"), py::arg("det_beam") = INF_DET_BEAM, + py::arg("beam_climbing") = false, py::arg("no_revisit_dets") = false, + py::arg("at_most_two_errors_per_detector") = false, + py::arg("verbose") = false, + py::arg("pqlimit") = std::numeric_limits::max(), + py::arg("det_orders") = std::vector>(), + py::arg("det_penalty") = 0.0) + .def_readwrite("dem", &TesseractConfig::dem) + .def_readwrite("det_beam", &TesseractConfig::det_beam) + .def_readwrite("no_revisit_dets", &TesseractConfig::no_revisit_dets) + .def_readwrite("at_most_two_errors_per_detector", + &TesseractConfig::at_most_two_errors_per_detector) + .def_readwrite("verbose", &TesseractConfig::verbose) + .def_readwrite("pqlimit", &TesseractConfig::pqlimit) + .def_readwrite("det_orders", &TesseractConfig::det_orders) + .def_readwrite("det_penalty", &TesseractConfig::det_penalty) + .def("__str__", &TesseractConfig::str); + + py::class_(m, "Node") + .def(py::init, std::vector, double, size_t, + std::vector>(), + py::arg("errs") = std::vector(), + py::arg("dets") = std::vector(), py::arg("cost") = 0.0, + py::arg("num_dets") = 0, + py::arg("blocked_errs") = std::vector()) + .def_readwrite("errs", &Node::errs) + .def_readwrite("dets", &Node::dets) + .def_readwrite("cost", &Node::cost) + .def_readwrite("num_dets", &Node::num_dets) + .def_readwrite("blocked_errs", &Node::blocked_errs) + .def(py::self > py::self) + .def("__str__", &Node::str); + + py::class_(m, "QNode") + .def(py::init>(), + py::arg("cost") = 0.0, py::arg("num_dets") = 0, + py::arg("errs") = std::vector()) + .def_readwrite("cost", &QNode::cost) + .def_readwrite("num_dets", &QNode::num_dets) + .def_readwrite("errs", &QNode::errs) + .def(py::self > py::self) + .def("__str__", &QNode::str); + + py::class_(m, "TesseractDecoder") + .def(py::init(), py::arg("config")) + .def("decode_to_errors", + py::overload_cast &>( + &TesseractDecoder::decode_to_errors), + py::arg("detections")) + .def("decode_to_errors", + py::overload_cast &, size_t>( + &TesseractDecoder::decode_to_errors), + py::arg("detections"), py::arg("det_order")) + .def("mask_from_errors", &TesseractDecoder::mask_from_errors, + py::arg("predicted_errors")) + .def("cost_from_errors", &TesseractDecoder::cost_from_errors, + py::arg("predicted_errors")) + .def("decode", &TesseractDecoder::decode, py::arg("detections")) + .def_readwrite("low_confidence_flag", + &TesseractDecoder::low_confidence_flag) + .def_readwrite("predicted_errors_buffer", + &TesseractDecoder::predicted_errors_buffer) + .def_readwrite("det_beam", &TesseractDecoder::det_beam) + .def_readwrite("errors", &TesseractDecoder::errors); +} + +#endif diff --git a/src/utils.pybind.h b/src/utils.pybind.h new file mode 100644 index 00000000..14ca72f4 --- /dev/null +++ b/src/utils.pybind.h @@ -0,0 +1,38 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _UTILS_PYBIND_H +#define _UTILS_PYBIND_H + +#include +#include +#include + +#include "utils.h" + +namespace py = pybind11; + +void add_utils_module(py::module &root) { + auto m = root.def_submodule("utils", "utility methods"); + + m.attr("EPSILON") = EPSILON; + m.attr("INF") = INF; + m.def("get_detector_coords", &get_detector_coords, py::arg("dem")); + m.def("build_detector_graph", &build_detector_graph, py::arg("dem")); + m.def("get_errors_from_dem", &get_errors_from_dem, py::arg("dem")); + + // Not exposing sampling_from_dem and sample_shots because they depend on + // stim::SparseShot which stim doesn't expose to python. +} +#endif