Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,15 @@ pybind_library(
name = "tesseract_decoder_pybind",
srcs = [
"common.pybind.h",
"utils.pybind.h",
"simplex.pybind.h",
"tesseract.pybind.h",
],
deps = [
":libcommon",
":libutils",
":libsimplex",
":libtesseract",
"@stim_py//:stim_pybind_lib",
],
)
Expand Down
99 changes: 60 additions & 39 deletions src/common.pybind.h
Original file line number Diff line number Diff line change
@@ -1,54 +1,75 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TESSERACT_COMMON_PY_H
#define TESSERACT_COMMON_PY_H

#include <vector>

#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "src/stim/dem/dem_instruction.pybind.h"
#include "stim/dem/detector_error_model_target.pybind.h"
#include <vector>

#include "common.h"
#include "src/stim/dem/dem_instruction.pybind.h"
#include "stim/dem/detector_error_model_target.pybind.h"

namespace py = pybind11;

void add_common_module(py::module &root)
{
auto m = root.def_submodule("common", "classes commonly used by the decoder");

py::class_<common::Symptom>(m, "Symptom")
.def(py::init<std::vector<int>, common::ObservablesMask>(),
py::arg("detectors") = std::vector<int>(),
py::arg("observables") = 0)
.def_readwrite("detectors", &common::Symptom::detectors)
.def_readwrite("observables", &common::Symptom::observables)
.def("__str__", &common::Symptom::str)
.def(py::self == py::self)
.def(py::self != py::self)
.def("as_dem_instruction_targets", [](common::Symptom s)
{
std::vector<stim_pybind::ExposedDemTarget> ret;
for(auto & t : s.as_dem_instruction_targets()) ret.emplace_back(t);
return ret; });

py::class_<common::Error>(m, "Error")
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
.def_readwrite("probability", &common::Error::probability)
.def_readwrite("symptom", &common::Error::symptom)
.def("__str__", &common::Error::str)
.def(py::init<>())
.def(py::init<double, std::vector<int> &, common::ObservablesMask,
std::vector<bool> &>())
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask,
std::vector<bool> &>())
.def(py::init([](stim_pybind::ExposedDemInstruction edi)
{ return new common::Error(edi.as_dem_instruction()); }));

m.def("merge_identical_errors", &common::merge_identical_errors);
m.def("remove_zero_probability_errors", &common::remove_zero_probability_errors);
m.def("dem_from_counts", &common::dem_from_counts);
void add_common_module(py::module &root) {
auto m = root.def_submodule("common", "classes commonly used by the decoder");

py::class_<common::Symptom>(m, "Symptom")
.def(py::init<std::vector<int>, common::ObservablesMask>(),
py::arg("detectors") = std::vector<int>(),
py::arg("observables") = 0)
.def_readwrite("detectors", &common::Symptom::detectors)
.def_readwrite("observables", &common::Symptom::observables)
.def("__str__", &common::Symptom::str)
.def(py::self == py::self)
.def(py::self != py::self)
.def("as_dem_instruction_targets", [](common::Symptom s) {
std::vector<stim_pybind::ExposedDemTarget> ret;
for (auto &t : s.as_dem_instruction_targets()) ret.emplace_back(t);
return ret;
});

py::class_<common::Error>(m, "Error")
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
.def_readwrite("probability", &common::Error::probability)
.def_readwrite("symptom", &common::Error::symptom)
.def("__str__", &common::Error::str)
.def(py::init<>())
.def(py::init<double, std::vector<int> &, common::ObservablesMask,
std::vector<bool> &>(),
py::arg("likelihood_cost"), py::arg("detectors"),
py::arg("observables"), py::arg("dets_array"))
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask,
std::vector<bool> &>(),
py::arg("likelihood_cost"), py::arg("probability"),
py::arg("detectors"), py::arg("observables"), py::arg("dets_array"))
.def(py::init([](stim_pybind::ExposedDemInstruction edi) {
return new common::Error(edi.as_dem_instruction());
}),
py::arg("error"));

m.def("merge_identical_errors", &common::merge_identical_errors,
py::arg("dem"));
m.def("remove_zero_probability_errors",
&common::remove_zero_probability_errors, py::arg("dem"));
m.def("dem_from_counts", &common::dem_from_counts, py::arg("orig_dem"),
py::arg("error_counts"), py::arg("num_shots"));
}

#endif
30 changes: 30 additions & 0 deletions src/py/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,36 @@ py_test(
],
)

py_test(
name = "utils_test",
srcs = ["utils_test.py"],
visibility = ["//:__subpackages__"],
deps = [
"@pypi//pytest",
"//src:lib_tesseract_decoder",
],
)

py_test(
name = "simplex_test",
srcs = ["simplex_test.py"],
visibility = ["//:__subpackages__"],
deps = [
"@pypi//pytest",
"//src:lib_tesseract_decoder",
],
)

py_test(
name = "tesseract_test",
srcs = ["tesseract_test.py"],
visibility = ["//:__subpackages__"],
deps = [
"@pypi//pytest",
"//src:lib_tesseract_decoder",
],
)

compile_pip_requirements(
name = "requirements",
src = "requirements.in",
Expand Down
14 changes: 14 additions & 0 deletions src/py/common_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:#www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import stim

Expand Down
52 changes: 52 additions & 0 deletions src/py/simplex_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:#www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import pytest
import stim

from src import tesseract_decoder

_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
"""
error(0.125) D0
error(0.375) D0 D1
error(0.25) D1
"""
)


def test_create_simplex_config():
sc = tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5)
assert sc.dem == _DETECTOR_ERROR_MODEL
assert sc.window_length == 5
assert (
str(sc)
== "SimplexConfig(dem=DetectorErrorModel_Object, window_length=5, window_slide_length=0, verbose=0)"
)


def test_create_simplex_decoder():
decoder = tesseract_decoder.simplex.SimplexDecoder(
tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5)
)
decoder.init_ilp()
decoder.decode_to_errors([1])
assert decoder.mask_from_errors([1]) == 0
assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123)
assert decoder.decode([1, 2]) == 0


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
59 changes: 59 additions & 0 deletions src/py/tesseract_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:#www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import pytest
import stim

from src import tesseract_decoder

_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
"""
error(0.125) D0
error(0.375) D0 D1
error(0.25) D1
"""
)


def test_create_config():
assert (
str(tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL))
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)"
)


def test_create_node():
node = tesseract_decoder.tesseract.Node(dets=["a"])
assert node.dets == ["a"]


def test_create_qnode():
qnode = tesseract_decoder.tesseract.QNode(num_dets=5, errs=[42])
assert qnode.num_dets == 5
assert str(qnode) == "QNode(cost=0, num_dets=5, errs=[42])"


def test_create_decoder():
config = tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL)
decoder = tesseract_decoder.tesseract.TesseractDecoder(config)
decoder.decode_to_errors([0])
decoder.decode_to_errors([0], 0)
assert decoder.mask_from_errors([1]) == 0
assert decoder.cost_from_errors([1]) == pytest.approx(1.609438)
assert decoder.decode([0]) == 0


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
58 changes: 58 additions & 0 deletions src/py/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:#www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import pytest
import stim

from src import tesseract_decoder


_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
"""
error(0.125) D0
error(0.375) D0 D1
error(0.25) D1
"""
)


def test_module_has_global_constants():
assert tesseract_decoder.utils.EPSILON <= 1e-7
assert not math.isfinite(tesseract_decoder.utils.INF)


def test_get_detector_coords():
assert tesseract_decoder.utils.get_detector_coords(_DETECTOR_ERROR_MODEL) == []


def test_build_detector_graph():
assert tesseract_decoder.utils.build_detector_graph(_DETECTOR_ERROR_MODEL) == [
[1],
[0],
]


def test_get_errors_from_dem():
expected = "Error{cost=1.945910, symptom=Symptom{D0 }}, Error{cost=0.510826, symptom=Symptom{D0 D1 }}, Error{cost=1.098612, symptom=Symptom{D1 }}"
assert (
", ".join(
map(str, tesseract_decoder.utils.get_errors_from_dem(_DETECTOR_ERROR_MODEL))
)
== expected
)


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
11 changes: 11 additions & 0 deletions src/simplex.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@

constexpr size_t T_COORD = 2;

std::string SimplexConfig::str() {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice! thanks for adding this!

auto & self = *this;
std::stringstream ss;
ss << "SimplexConfig(";
ss << "dem=" << "DetectorErrorModel_Object" << ", ";
ss << "window_length=" << self.window_length << ", ";
ss << "window_slide_length=" << self.window_slide_length << ", ";
ss << "verbose=" << self.verbose << ")";
return ss.str();
}

SimplexDecoder::SimplexDecoder(SimplexConfig _config) : config(_config) {
config.dem = common::remove_zero_probability_errors(config.dem);
std::vector<double> detector_t_coords(config.dem.count_detectors());
Expand Down
1 change: 1 addition & 0 deletions src/simplex.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ struct SimplexConfig {
size_t window_slide_length = 0;
bool verbose = false;
bool windowing_enabled() { return (window_length != 0); }
std::string str();
};

struct SimplexDecoder {
Expand Down
Loading