Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,15 @@ pybind_library(
name = "tesseract_decoder_pybind",
srcs = [
"common.pybind.h",
"utils.pybind.h",
"simplex.pybind.h",
"tesseract.pybind.h",
],
deps = [
":libcommon",
":libutils",
":libsimplex",
":libtesseract",
"@stim_py//:stim_pybind_lib",
],
)
Expand Down
15 changes: 9 additions & 6 deletions src/common.pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,18 @@ void add_common_module(py::module &root)
.def("__str__", &common::Error::str)
.def(py::init<>())
.def(py::init<double, std::vector<int> &, common::ObservablesMask,
std::vector<bool> &>())
std::vector<bool> &>(),
py::arg("likelihood_cost"), py::arg("detectors"), py::arg("observables"), py::arg("dets_array"))
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask,
std::vector<bool> &>())
std::vector<bool> &>(),
py::arg("likelihood_cost"), py::arg("probability"), py::arg("detectors"), py::arg("observables"), py::arg("dets_array"))
.def(py::init([](stim_pybind::ExposedDemInstruction edi)
{ return new common::Error(edi.as_dem_instruction()); }));
{ return new common::Error(edi.as_dem_instruction()); }),
py::arg("error"));

m.def("merge_identical_errors", &common::merge_identical_errors);
m.def("remove_zero_probability_errors", &common::remove_zero_probability_errors);
m.def("dem_from_counts", &common::dem_from_counts);
m.def("merge_identical_errors", &common::merge_identical_errors, py::arg("dem"));
m.def("remove_zero_probability_errors", &common::remove_zero_probability_errors, py::arg("dem"));
m.def("dem_from_counts", &common::dem_from_counts, py::arg("orig_dem"), py::arg("error_counts"), py::arg("num_shots"));
}

#endif
30 changes: 30 additions & 0 deletions src/py/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,36 @@ py_test(
],
)

py_test(
name = "utils_test",
srcs = ["utils_test.py"],
visibility = ["//:__subpackages__"],
deps = [
"@pypi//pytest",
"//src:lib_tesseract_decoder",
],
)

py_test(
name = "simplex_test",
srcs = ["simplex_test.py"],
visibility = ["//:__subpackages__"],
deps = [
"@pypi//pytest",
"//src:lib_tesseract_decoder",
],
)

py_test(
name = "tesseract_test",
srcs = ["tesseract_test.py"],
visibility = ["//:__subpackages__"],
deps = [
"@pypi//pytest",
"//src:lib_tesseract_decoder",
],
)

compile_pip_requirements(
name = "requirements",
src = "requirements.in",
Expand Down
38 changes: 38 additions & 0 deletions src/py/simplex_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import math
import pytest
import stim

from src import tesseract_decoder

_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
"""
error(0.125) D0
error(0.375) D0 D1
error(0.25) D1
"""
)


def test_create_simplex_config():
sc = tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5)
assert sc.dem == _DETECTOR_ERROR_MODEL
assert sc.window_length == 5
assert (
str(sc)
== "SimplexConfig(dem=DetectorErrorModel_Object, window_length=5, window_slide_length=0, verbose=0)"
)


def test_create_simplex_decoder():
decoder = tesseract_decoder.simplex.SimplexDecoder(
tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5)
)
decoder.init_ilp()
decoder.decode_to_errors([1])
assert decoder.mask_from_errors([1]) == 0
assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123)
assert decoder.decode([1, 2]) == 0


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
45 changes: 45 additions & 0 deletions src/py/tesseract_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import math
import pytest
import stim

from src import tesseract_decoder

_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
"""
error(0.125) D0
error(0.375) D0 D1
error(0.25) D1
"""
)


def test_create_config():
assert (
str(tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL))
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)"
)


def test_create_node():
node = tesseract_decoder.tesseract.Node(dets=["a"])
assert node.dets == ["a"]


def test_create_qnode():
qnode = tesseract_decoder.tesseract.QNode(num_dets=5, errs=[42])
assert qnode.num_dets == 5
assert str(qnode) == "QNode(cost=0, num_dets=5, errs=[42])"


def test_create_decoder():
config = tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL)
decoder = tesseract_decoder.tesseract.TesseractDecoder(config)
decoder.decode_to_errors([0])
decoder.decode_to_errors([0], 0)
assert decoder.mask_from_errors([1]) == 0
assert decoder.cost_from_errors([1]) == pytest.approx(1.609438)
assert decoder.decode([0]) == 0


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
44 changes: 44 additions & 0 deletions src/py/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import math
import pytest
import stim

from src import tesseract_decoder


_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
"""
error(0.125) D0
error(0.375) D0 D1
error(0.25) D1
"""
)


def test_module_has_global_constants():
assert tesseract_decoder.utils.EPSILON <= 1e-7
assert not math.isfinite(tesseract_decoder.utils.INF)


def test_get_detector_coords():
assert tesseract_decoder.utils.get_detector_coords(_DETECTOR_ERROR_MODEL) == []


def test_build_detector_graph():
assert tesseract_decoder.utils.build_detector_graph(_DETECTOR_ERROR_MODEL) == [
[1],
[0],
]


def test_get_errors_from_dem():
expected = "Error{cost=1.945910, symptom=Symptom{D0 }}, Error{cost=0.510826, symptom=Symptom{D0 D1 }}, Error{cost=1.098612, symptom=Symptom{D1 }}"
assert (
", ".join(
map(str, tesseract_decoder.utils.get_errors_from_dem(_DETECTOR_ERROR_MODEL))
)
== expected
)


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
11 changes: 11 additions & 0 deletions src/simplex.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@

constexpr size_t T_COORD = 2;

std::string SimplexConfig::str() {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice! thanks for adding this!

auto & self = *this;
std::stringstream ss;
ss << "SimplexConfig(";
ss << "dem=" << "DetectorErrorModel_Object" << ", ";
ss << "window_length=" << self.window_length << ", ";
ss << "window_slide_length=" << self.window_slide_length << ", ";
ss << "verbose=" << self.verbose << ")";
return ss.str();
}

SimplexDecoder::SimplexDecoder(SimplexConfig _config) : config(_config) {
config.dem = common::remove_zero_probability_errors(config.dem);
std::vector<double> detector_t_coords(config.dem.count_detectors());
Expand Down
1 change: 1 addition & 0 deletions src/simplex.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ struct SimplexConfig {
size_t window_slide_length = 0;
bool verbose = false;
bool windowing_enabled() { return (window_length != 0); }
std::string str();
};

struct SimplexDecoder {
Expand Down
50 changes: 50 additions & 0 deletions src/simplex.pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef _SIMPLEX_PYBIND_H
#define _SIMPLEX_PYBIND_H

#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "common.h"
#include "simplex.h"

namespace py = pybind11;

void add_simplex_module(py::module &root)
{
auto m = root.def_submodule(
"simplex", "Module containing the SimplexDecoder and related methods");

py::class_<SimplexConfig>(m, "SimplexConfig")
.def(py::init<stim::DetectorErrorModel, bool, size_t, size_t, bool>(),
py::arg("dem"), py::arg("parallelize") = false,
py::arg("window_length") = 0, py::arg("window_slide_length") = 0,
py::arg("verbose") = false)
.def_readwrite("dem", &SimplexConfig::dem)
.def_readwrite("parallelize", &SimplexConfig::parallelize)
.def_readwrite("window_length", &SimplexConfig::window_length)
.def_readwrite("window_slide_length", &SimplexConfig::window_slide_length)
.def_readwrite("verbose", &SimplexConfig::verbose)
.def("windowing_enabled", &SimplexConfig::windowing_enabled)
.def("__str__", &SimplexConfig::str);

py::class_<SimplexDecoder>(m, "SimplexDecoder")
.def(py::init<SimplexConfig>(), py::arg("config"))
.def_readwrite("config", &SimplexDecoder::config)
.def_readwrite("errors", &SimplexDecoder::errors)
.def_readwrite("num_detectors", &SimplexDecoder::num_detectors)
.def_readwrite("num_observables", &SimplexDecoder::num_observables)
.def_readwrite("predicted_errors_buffer",
&SimplexDecoder::predicted_errors_buffer)
.def_readwrite("error_masks", &SimplexDecoder::error_masks)
.def_readwrite("start_time_to_errors",
&SimplexDecoder::start_time_to_errors)
.def_readwrite("end_time_to_errors", &SimplexDecoder::end_time_to_errors)
.def_readonly("low_confidence_flag", &SimplexDecoder::low_confidence_flag)
.def("init_ilp", &SimplexDecoder::init_ilp)
.def("decode_to_errors", &SimplexDecoder::decode_to_errors, py::arg("detections"))
.def("mask_from_errors", &SimplexDecoder::mask_from_errors, py::arg("predicted_errors"))
.def("cost_from_errors", &SimplexDecoder::cost_from_errors, py::arg("predicted_errors"))
.def("decode", &SimplexDecoder::decode, py::arg("detections"));
}
#endif
62 changes: 62 additions & 0 deletions src/tesseract.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,72 @@
#include <cassert>
#include <iostream>

namespace
{

template <typename T>
std::ostream &operator<<(std::ostream &os, const std::vector<T> &vec)
{
os << "[";
bool is_first = true;
for (auto &x : vec)
{
if (!is_first)
{
os << ", ";
}
is_first = false;
os << x;
}
os << "]";
return os;
}

};

std::string TesseractConfig::str()
{
auto &config = *this;
std::stringstream ss;
ss << "TesseractConfig(";
ss << "dem=DetectorErrorModel_Object" << ", ";
ss << "det_beam=" << config.det_beam << ", ";
ss << "no_revisit_dets=" << config.no_revisit_dets << ", ";
ss << "at_most_two_errors_per_detector=" << config.at_most_two_errors_per_detector << ", ";
ss << "verbose=" << config.verbose << ", ";
ss << "pqlimit=" << config.pqlimit << ", ";
ss << "det_orders=" << config.det_orders << ", ";
ss << "det_penalty=" << config.det_penalty << ")";
return ss.str();
}

bool Node::operator>(const Node& other) const {
return cost > other.cost || (cost == other.cost && num_dets < other.num_dets);
}

std::string Node::str()
{
std::stringstream ss;
auto &self = *this;
ss << "Node(";
ss << "errs=" << self.errs << ", ";
ss << "dets=" << self.dets << ", ";
ss << "cost=" << self.cost << ", ";
ss << "num_dets=" << self.num_dets << ", ";
ss << "blocked_errs=" << self.blocked_errs << ")";
return ss.str();
}

std::string QNode::str() {
auto & self = *this;
std::stringstream ss;
ss << "QNode(";
ss << "cost=" << self.cost << ", ";
ss << "num_dets=" << self.num_dets << ", ";
ss << "errs=" << self.errs << ")";
return ss.str();
}

double TesseractDecoder::get_detcost(size_t d,
const std::vector<char>& blocked_errs,
const std::vector<size_t>& det_counts) const {
Expand Down
4 changes: 4 additions & 0 deletions src/tesseract.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ struct TesseractConfig {
size_t pqlimit = std::numeric_limits<size_t>::max();
std::vector<std::vector<size_t>> det_orders;
double det_penalty = 0;

std::string str();
};

class Node {
Expand All @@ -48,6 +50,7 @@ class Node {
std::vector<char> blocked_errs;

bool operator>(const Node& other) const;
std::string str();
};

class QNode {
Expand All @@ -57,6 +60,7 @@ class QNode {
std::vector<size_t> errs;

bool operator>(const QNode& other) const;
std::string str();
};

struct TesseractDecoder {
Expand Down
Loading