Skip to content

Commit 260799e

Browse files
authored
Merge pull request #33 from noajshu/codex/refactor-tesseractconfig-and-simplexconfig-for-verbose-callb
feat: add verbose callback for decoders
2 parents 1de3cad + d4d0040 commit 260799e

13 files changed

Lines changed: 260 additions & 106 deletions

src/py/common_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616
import stim
1717

18-
from src import tesseract_decoder
18+
import tesseract_decoder
1919

2020
def get_set_bits(n):
2121
"""

src/py/simplex_test.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717
import stim
1818

19-
from src import tesseract_decoder
19+
import tesseract_decoder
2020

2121
_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
2222
"""
@@ -33,7 +33,7 @@ def test_create_simplex_config():
3333
assert sc.window_length == 5
3434
assert (
3535
str(sc)
36-
== "SimplexConfig(dem=DetectorErrorModel_Object, window_length=5, window_slide_length=0, verbose=0)"
36+
== "SimplexConfig(dem=DetectorErrorModel_Object, window_length=5, window_slide_length=0)"
3737
)
3838

3939

@@ -46,6 +46,20 @@ def test_create_simplex_decoder():
4646
assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123)
4747
assert decoder.decode([1]) == []
4848

49+
50+
def test_simplex_verbose_callback_receives_output():
51+
lines = []
52+
53+
def cb(s: str) -> None:
54+
lines.append(s)
55+
56+
config = tesseract_decoder.simplex.SimplexConfig(
57+
_DETECTOR_ERROR_MODEL, window_length=5, verbose_callback=cb
58+
)
59+
decoder = tesseract_decoder.simplex.SimplexDecoder(config)
60+
decoder.decode_to_errors([1])
61+
assert any(lines)
62+
4963
def test_simplex_decoder_predicts_various_observable_flips():
5064
"""
5165
Tests that the Simplex decoder correctly predicts a logical observable

src/py/tesseract_test.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717
import stim
1818

19-
from src import tesseract_decoder
19+
import tesseract_decoder
2020

2121
_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(
2222
"""
@@ -30,7 +30,7 @@
3030
def test_create_config():
3131
assert (
3232
str(tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL))
33-
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)"
33+
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)"
3434
)
3535
assert (
3636
tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL).dem
@@ -52,6 +52,20 @@ def test_create_decoder():
5252
assert decoder.cost_from_errors([1]) == pytest.approx(0.5108256237659907)
5353
assert decoder.decode([0]) == []
5454

55+
56+
def test_tesseract_verbose_callback_receives_output():
57+
lines = []
58+
59+
def cb(s: str) -> None:
60+
lines.append(s)
61+
62+
config = tesseract_decoder.tesseract.TesseractConfig(
63+
_DETECTOR_ERROR_MODEL, verbose_callback=cb
64+
)
65+
decoder = tesseract_decoder.tesseract.TesseractDecoder(config)
66+
decoder.decode_to_errors([0])
67+
assert any(lines)
68+
5569
def test_tesseract_decoder_predicts_various_observable_flips():
5670
"""
5771
Tests that the Tesseract decoder correctly predicts a logical observable

src/py/utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717
import stim
1818

19-
from src import tesseract_decoder
19+
import tesseract_decoder
2020

2121

2222
_DETECTOR_ERROR_MODEL = stim.DetectorErrorModel(

src/simplex.cc

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,36 @@
1515
#include "simplex.h"
1616

1717
#include <cassert>
18+
#include <iostream>
1819

1920
#include "Highs.h"
2021
#include "io/HMPSIO.h"
2122

2223
constexpr size_t T_COORD = 2;
2324

25+
namespace {
26+
void highs_log_cb(HighsLogType, const char* msg, void* user_data) {
27+
CallbackStream* stream = static_cast<CallbackStream*>(user_data);
28+
(*stream) << msg;
29+
stream->flush();
30+
}
31+
} // namespace
32+
2433
std::string SimplexConfig::str() {
2534
auto& self = *this;
2635
std::stringstream ss;
2736
ss << "SimplexConfig(";
2837
ss << "dem=" << "DetectorErrorModel_Object" << ", ";
2938
ss << "window_length=" << self.window_length << ", ";
30-
ss << "window_slide_length=" << self.window_slide_length << ", ";
31-
ss << "verbose=" << self.verbose << ")";
39+
ss << "window_slide_length=" << self.window_slide_length << ")";
3240
return ss.str();
3341
}
3442

3543
SimplexDecoder::SimplexDecoder(SimplexConfig _config) : config(_config) {
44+
if (!config.verbose_callback) {
45+
config.verbose_callback = [](const std::string& s) { std::cout << s; };
46+
}
47+
config.log_stream.callback = config.verbose_callback;
3648
config.dem = common::remove_zero_probability_errors(config.dem);
3749
std::vector<double> detector_t_coords(config.dem.count_detectors());
3850
for (const stim::DemInstruction& instruction : config.dem.flattened().instructions) {
@@ -152,7 +164,11 @@ void SimplexDecoder::init_ilp() {
152164
// Disabled presolve entirely after encountering bugs similar to this one:
153165
// https://github.com/ERGO-Code/HiGHS/issues/1273
154166
highs->setOptionValue("presolve", "off");
155-
highs->setOptionValue("output_flag", config.verbose);
167+
highs->setOptionValue("output_flag", config.log_stream.active);
168+
highs->setOptionValue("log_to_console", false);
169+
if (config.log_stream.active) {
170+
highs->setLogCallback(highs_log_cb, &config.log_stream);
171+
}
156172
}
157173

158174
void SimplexDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
@@ -197,9 +213,7 @@ void SimplexDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
197213
add_costs_for_time(t1);
198214
++t1;
199215
}
200-
if (config.verbose) {
201-
std::cout << "t0 = " << t0 << " t1 = " << t1 << std::endl;
202-
}
216+
config.log_stream << "t0 = " << t0 << " t1 = " << t1 << std::endl;
203217

204218
// Pass the model to HiGHS
205219
*return_status = highs->passModel(*model);
@@ -235,16 +249,20 @@ void SimplexDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
235249
}
236250
assert(*return_status == HighsStatus::kOk);
237251

238-
if (config.verbose) {
239-
// Get the solution information
252+
if (config.log_stream.active) {
240253
const HighsInfo& info = highs->getInfo();
241-
std::cout << "Simplex iteration count: " << info.simplex_iteration_count << std::endl;
242-
std::cout << "Objective function value: " << info.objective_function_value << std::endl;
243-
std::cout << "Primal solution status: "
244-
<< highs->solutionStatusToString(info.primal_solution_status) << std::endl;
245-
std::cout << "Dual solution status: "
246-
<< highs->solutionStatusToString(info.dual_solution_status) << std::endl;
247-
std::cout << "Basis: " << highs->basisValidityToString(info.basis_validity) << std::endl;
254+
config.log_stream << "Simplex iteration count: " << info.simplex_iteration_count
255+
<< std::endl;
256+
config.log_stream << "Objective function value: " << info.objective_function_value
257+
<< std::endl;
258+
config.log_stream << "Primal solution status: "
259+
<< highs->solutionStatusToString(info.primal_solution_status)
260+
<< std::endl;
261+
config.log_stream << "Dual solution status: "
262+
<< highs->solutionStatusToString(info.dual_solution_status)
263+
<< std::endl;
264+
config.log_stream << "Basis: "
265+
<< highs->basisValidityToString(info.basis_validity) << std::endl;
248266
}
249267

250268
// Get the model status
@@ -286,16 +304,20 @@ void SimplexDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
286304
*return_status = highs->run();
287305
assert(*return_status == HighsStatus::kOk);
288306

289-
if (config.verbose) {
290-
// Get the solution information
307+
if (config.log_stream.active) {
291308
const HighsInfo& info = highs->getInfo();
292-
std::cout << "Simplex iteration count: " << info.simplex_iteration_count << std::endl;
293-
std::cout << "Objective function value: " << info.objective_function_value << std::endl;
294-
std::cout << "Primal solution status: "
295-
<< highs->solutionStatusToString(info.primal_solution_status) << std::endl;
296-
std::cout << "Dual solution status: "
297-
<< highs->solutionStatusToString(info.dual_solution_status) << std::endl;
298-
std::cout << "Basis: " << highs->basisValidityToString(info.basis_validity) << std::endl;
309+
config.log_stream << "Simplex iteration count: " << info.simplex_iteration_count
310+
<< std::endl;
311+
config.log_stream << "Objective function value: " << info.objective_function_value
312+
<< std::endl;
313+
config.log_stream << "Primal solution status: "
314+
<< highs->solutionStatusToString(info.primal_solution_status)
315+
<< std::endl;
316+
config.log_stream << "Dual solution status: "
317+
<< highs->solutionStatusToString(info.dual_solution_status)
318+
<< std::endl;
319+
config.log_stream << "Basis: "
320+
<< highs->basisValidityToString(info.basis_validity) << std::endl;
299321
}
300322

301323
// Get the model status

src/simplex.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414

1515
#ifndef SIMPLEX_HPP
1616
#define SIMPLEX_HPP
17+
#include <functional>
1718
#include <unordered_set>
1819
#include <vector>
1920

2021
#include "common.h"
2122
#include "stim.h"
23+
#include "utils.h"
2224

2325
struct HighsModel;
2426
struct Highs;
@@ -29,7 +31,8 @@ struct SimplexConfig {
2931
bool parallelize = false;
3032
size_t window_length = 0;
3133
size_t window_slide_length = 0;
32-
bool verbose = false;
34+
std::function<void(const std::string&)> verbose_callback;
35+
CallbackStream log_stream;
3336
bool windowing_enabled() {
3437
return (window_length != 0);
3538
}

src/simplex.pybind.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <pybind11/operators.h>
1919
#include <pybind11/pybind11.h>
2020
#include <pybind11/stl.h>
21+
#include <iostream>
2122

2223
#include "common.h"
2324
#include "simplex.h"
@@ -28,9 +29,26 @@ namespace py = pybind11;
2829
namespace {
2930
SimplexConfig simplex_config_maker(py::object dem, bool parallelize = false,
3031
size_t window_length = 0, size_t window_slide_length = 0,
31-
bool verbose = false) {
32+
py::object verbose_callback = py::none()) {
3233
stim::DetectorErrorModel input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
33-
return SimplexConfig({input_dem, parallelize, window_length, window_slide_length, verbose});
34+
SimplexConfig cfg;
35+
cfg.dem = input_dem;
36+
cfg.parallelize = parallelize;
37+
cfg.window_length = window_length;
38+
cfg.window_slide_length = window_slide_length;
39+
std::function<void(const std::string&)> cb;
40+
bool active = false;
41+
if (!verbose_callback.is_none()) {
42+
py::function f = verbose_callback;
43+
cb = [f](const std::string& s) {
44+
py::gil_scoped_acquire gil;
45+
f(s);
46+
};
47+
active = true;
48+
}
49+
cfg.verbose_callback = cb;
50+
cfg.log_stream = CallbackStream(active, cfg.verbose_callback);
51+
return cfg;
3452
}
3553

3654
}; // namespace
@@ -42,12 +60,11 @@ void add_simplex_module(py::module& root) {
4260
py::class_<SimplexConfig>(m, "SimplexConfig")
4361
.def(py::init(&simplex_config_maker), py::arg("dem"), py::arg("parallelize") = false,
4462
py::arg("window_length") = 0, py::arg("window_slide_length") = 0,
45-
py::arg("verbose") = false)
63+
py::arg("verbose_callback") = py::none())
4664
.def_property("dem", &dem_getter<SimplexConfig>, &dem_setter<SimplexConfig>)
4765
.def_readwrite("parallelize", &SimplexConfig::parallelize)
4866
.def_readwrite("window_length", &SimplexConfig::window_length)
4967
.def_readwrite("window_slide_length", &SimplexConfig::window_slide_length)
50-
.def_readwrite("verbose", &SimplexConfig::verbose)
5168
.def("windowing_enabled", &SimplexConfig::windowing_enabled)
5269
.def("__str__", &SimplexConfig::str);
5370

src/simplex_main.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <argparse/argparse.hpp>
1616
#include <atomic>
1717
#include <fstream>
18+
#include <iostream>
1819
#include <nlohmann/json.hpp>
1920
#include <thread>
2021

@@ -134,6 +135,8 @@ struct Args {
134135

135136
void extract(SimplexConfig& config, std::vector<stim::SparseShot>& shots,
136137
std::unique_ptr<stim::MeasureRecordWriter>& writer) {
138+
config.verbose_callback = [](const std::string& s) { std::cout << s; };
139+
config.log_stream = CallbackStream(verbose, config.verbose_callback);
137140
// Get a circuit, if available
138141
stim::Circuit circuit;
139142
if (!circuit_path.empty()) {
@@ -261,7 +264,8 @@ struct Args {
261264
config.parallelize = enable_ilp_solver_parallelism;
262265
config.window_length = window_length;
263266
config.window_slide_length = window_slide_length;
264-
config.verbose = verbose;
267+
config.log_stream.active = verbose;
268+
config.log_stream.callback = config.verbose_callback;
265269
}
266270
};
267271

0 commit comments

Comments
 (0)