Skip to content

Commit d6ad104

Browse files
committed
feat: expose C++ properties to Pybind and update Sinter build dependencies
TAG=agy
1 parent 27cffb5 commit d6ad104

4 files changed

Lines changed: 54 additions & 45 deletions

File tree

BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ py_wheel(
1919
name="tesseract_decoder_wheel",
2020
distribution = "tesseract_decoder",
2121
deps=[
22-
"//src:tesseract_decoder",
22+
"//src:_core",
2323
"//src/py:generated_stubs",
24-
"//src/py/_tesseract_py_util:_tesseract_py_util",
24+
"//src/py:tesseract_decoder",
2525
":package_data",
2626
],
2727
version = "$(VERSION)",

src/error_correlations.test.cc

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,61 @@
1-
#include "gtest/gtest.h"
21
#include "error_correlations.h"
2+
33
#include <vector>
44

5+
#include "gtest/gtest.h"
6+
57
using namespace tesseract;
68

79
TEST(TwoPassCorrelationsTest, JointProbabilities) {
8-
stim::DetectorErrorModel dem(R"DEM(
10+
stim::DetectorErrorModel dem(R"DEM(
911
error(0.1) D0 ^ D1
1012
error(0.2) D0
1113
)DEM");
1214

13-
auto joint = get_hyperedge_joint_probabilities(dem);
14-
15-
Hyperedge h0 = {0};
16-
Hyperedge h1 = {1};
17-
18-
// P(D0) = 0.1 XOR 0.2 = 0.1*(1-0.2) + 0.2*(1-0.1) = 0.08 + 0.18 = 0.26
19-
EXPECT_NEAR(joint[h0][h0], 0.26, 1e-6);
20-
// P(D1) = 0.1
21-
EXPECT_NEAR(joint[h1][h1], 0.1, 1e-6);
22-
// P(D0 and D1) = 0.1
23-
EXPECT_NEAR(joint[h0][h1], 0.1, 1e-6);
24-
EXPECT_NEAR(joint[h1][h0], 0.1, 1e-6);
15+
std::vector<int> global_det_to_comp_id = {0, 1};
16+
auto joint = get_hyperedge_joint_probabilities(dem, global_det_to_comp_id);
17+
18+
Hyperedge h0 = {0};
19+
Hyperedge h1 = {1};
20+
21+
// P(D0) = 0.1 XOR 0.2 = 0.1*(1-0.2) + 0.2*(1-0.1) = 0.08 + 0.18 = 0.26
22+
EXPECT_NEAR(joint[h0][h0], 0.26, 1e-6);
23+
// P(D1) = 0.1
24+
EXPECT_NEAR(joint[h1][h1], 0.1, 1e-6);
25+
// P(D0 and D1) = 0.1
26+
EXPECT_NEAR(joint[h0][h1], 0.1, 1e-6);
27+
EXPECT_NEAR(joint[h1][h0], 0.1, 1e-6);
2528
}
2629

2730
TEST(TwoPassCorrelationsTest, ImpliedProbabilities) {
28-
JointProbsMap joint;
29-
Hyperedge h0 = {0};
30-
Hyperedge h1 = {1};
31-
32-
joint[h0][h0] = 0.2;
33-
joint[h1][h1] = 0.1;
34-
joint[h0][h1] = 0.05;
35-
joint[h1][h0] = 0.05;
36-
37-
auto implied = get_implied_hyperedge_probabilities(joint);
38-
39-
// P(D1 | D0) = 0.05 / 0.2 = 0.25
40-
bool found = false;
41-
for (const auto& imp : implied[h0]) {
42-
if (imp.affected_hyperedge == h1) {
43-
EXPECT_NEAR(imp.probability, 0.25, 1e-6);
44-
found = true;
45-
}
31+
JointProbsMap joint;
32+
Hyperedge h0 = {0};
33+
Hyperedge h1 = {1};
34+
35+
joint[h0][h0] = 0.2;
36+
joint[h1][h1] = 0.1;
37+
joint[h0][h1] = 0.05;
38+
joint[h1][h0] = 0.05;
39+
40+
auto implied = get_implied_hyperedge_probabilities(joint);
41+
42+
// P(D1 | D0) = 0.05 / 0.2 = 0.25
43+
bool found = false;
44+
for (const auto& imp : implied[h0]) {
45+
if (imp.affected_hyperedge == h1) {
46+
EXPECT_NEAR(imp.probability, 0.25, 1e-6);
47+
found = true;
4648
}
47-
EXPECT_TRUE(found);
48-
49-
// P(D0 | D1) = 0.05 / 0.1 = 0.5
50-
found = false;
51-
for (const auto& imp : implied[h1]) {
52-
if (imp.affected_hyperedge == h0) {
53-
EXPECT_NEAR(imp.probability, 0.5, 1e-6);
54-
found = true;
55-
}
49+
}
50+
EXPECT_TRUE(found);
51+
52+
// P(D0 | D1) = 0.05 / 0.1 = 0.5
53+
found = false;
54+
for (const auto& imp : implied[h1]) {
55+
if (imp.affected_hyperedge == h0) {
56+
EXPECT_NEAR(imp.probability, 0.5, 1e-6);
57+
found = true;
5658
}
57-
EXPECT_TRUE(found);
59+
}
60+
EXPECT_TRUE(found);
5861
}

src/multi_pass_sinter_compat.pybind.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ void pybind_multi_pass_sinter_compat(py::module& m) {
126126
py::class_<MultiPassSinterCompiledDecoder>(m, "MultiPassSinterCompiledDecoder")
127127
.def_property_readonly("num_components", &MultiPassSinterCompiledDecoder::num_components)
128128
.def("decode_shots_bit_packed", &MultiPassSinterCompiledDecoder::decode_shots_bit_packed,
129-
py::kw_only(), py::arg("bit_packed_detection_event_data"));
129+
py::kw_only(), py::arg("bit_packed_detection_event_data"),
130+
py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>());
130131

131132
py::class_<MultiPassSinterDecoder>(m, "MultiPassSinterDecoder")
132133
.def(py::init<size_t>(), py::arg("num_passes") = 2)

src/py/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ py_library(
3131
data = [":copy_core_so"],
3232
imports = ["."],
3333
visibility = ["//visibility:public"],
34+
deps = [
35+
"@pypi//stim",
36+
"@pypi//numpy",
37+
"@pypi//sinter",
38+
],
3439
)
3540

3641
py_library(

0 commit comments

Comments
 (0)