Skip to content

Commit bccd851

Browse files
fix ci failure
Signed-off-by: vedika-saravanan <vsaravanan@nvidia.com>
1 parent e4717e9 commit bccd851

10 files changed

Lines changed: 222 additions & 55 deletions

File tree

libs/qec/include/cudaq/qec/decoder.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,11 +431,29 @@ std::unique_ptr<decoder>
431431
get_decoder(const std::string &name, const cudaqx::tensor<uint8_t> &H,
432432
const cudaqx::heterogeneous_map options = {});
433433

434+
struct detector_error_model;
435+
436+
/// @brief DEM-derived defaults; pointers alias into the source `dem`.
437+
struct dem_default_values {
438+
const cudaqx::tensor<uint8_t> *O = nullptr;
439+
const std::vector<double> *error_rate_vec = nullptr;
440+
};
441+
442+
/// @brief Return DEM defaults for any key not already supplied by the user.
443+
/// Shared by `get_decoder_from_stim_dem` and its Python binding.
444+
dem_default_values dem_defaults_for_missing_keys(
445+
const std::function<bool(const std::string &)> &contains_user_key,
446+
const detector_error_model &dem);
447+
434448
/// @brief Creator function for a decoder constructed from a Stim DEM string.
435449
using stim_dem_decoder_creator = std::function<std::unique_ptr<decoder>(
436450
const std::string &, const cudaqx::heterogeneous_map &)>;
437451

438452
/// @brief Register a Stim-DEM-string creator for the named decoder.
453+
/// @param name Decoder name; same name used by `get_decoder_from_stim_dem`.
454+
/// @param creator Builds a decoder from the raw DEM string + options. Takes
455+
/// precedence over the H/O fallback; must not re-enter the registry (the
456+
/// factory copies it out before invoking).
439457
/// @see get_decoder_from_stim_dem
440458
void register_stim_dem_decoder_creator(const std::string &name,
441459
stim_dem_decoder_creator creator);
@@ -451,12 +469,15 @@ void unregister_stim_dem_decoder_creator(const std::string &name);
451469
/// Otherwise the DEM is parsed and forwarded to the existing H-based path
452470
/// after injecting two derived entries into \p options if they are not
453471
/// already present:
454-
/// - \c "O" : \c cudaqx::tensor<uint8_t> observables_flips_matrix
455-
/// - \c "error_rate_vec" : \c std::vector<double> per-error probabilities
472+
/// - `"O"` : `cudaqx::tensor<uint8_t>` observables_flips_matrix
473+
/// - `"error_rate_vec"` : `std::vector<double>` per-error probabilities
456474
/// User-supplied values for either key win over the DEM-derived ones.
475+
///
476+
/// @note Decoders that need full DEM metadata (e.g. Chromobius) must
477+
/// register a Stim-DEM creator; the fallback only extracts H/O/rates.
457478
/// @see register_stim_dem_decoder_creator
458479
std::unique_ptr<decoder>
459480
get_decoder_from_stim_dem(const std::string &name,
460481
const std::string &stim_dem_text,
461-
const cudaqx::heterogeneous_map options = {});
482+
const cudaqx::heterogeneous_map &options = {});
462483
} // namespace cudaq::qec

libs/qec/include/cudaq/qec/detector_error_model.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,16 @@ struct detector_error_model {
6767
};
6868

6969
/// @brief Parse a Stim detector error model text into a
70-
/// \p cudaq::qec::detector_error_model. Each \c error instruction in the DEM
70+
/// \p cudaq::qec::detector_error_model. Each `error` instruction in the DEM
7171
/// becomes a single column in \p detector_error_matrix and
7272
/// \p observables_flips_matrix; suggested decomposition separators are
7373
/// folded into the same column.
74+
///
75+
/// @note Lossy: only detector/observable flips and error probabilities
76+
/// are extracted. Annotations (`detector`, `logical_observable`),
77+
/// suggested-decomposition separators, and \p error_ids are dropped.
78+
/// Decoders that need the full DEM (e.g. Chromobius) must consume the
79+
/// raw string via `register_stim_dem_decoder_creator`.
7480
detector_error_model dem_from_stim_text(const std::string &dem_text);
7581

7682
} // namespace cudaq::qec

libs/qec/lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ add_library(${LIBRARY_NAME} SHARED ${QEC_SOURCES})
6161
add_subdirectory(decoders/plugins/example)
6262
add_subdirectory(decoders/plugins/pymatching)
6363

64+
# libstim comes from the parent build (CUDA-Q).
6465
if(NOT TARGET libstim)
6566
message(FATAL_ERROR
6667
"libstim target not available; required by cudaq-qec for Stim DEM parsing.")

libs/qec/lib/decoder_stim_dem.cpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,40 +18,55 @@ namespace cudaq::qec {
1818

1919
namespace {
2020

21+
// std::mutex is enough: the factory copies the creator out before
22+
// invoking, so creators cannot re-enter the registry.
2123
struct stim_dem_registry {
22-
std::recursive_mutex &mutex;
24+
std::mutex &mutex;
2325
std::unordered_map<std::string, stim_dem_decoder_creator> &map;
2426
};
2527
stim_dem_registry get_stim_dem_registry() {
26-
static std::recursive_mutex *mutex = new std::recursive_mutex();
28+
// Heap-allocated to outlive static destructors (plugin dlclose unregister
29+
// path); matches the cudaqx extension_point pattern. See extension_point.h.
30+
static std::mutex *mutex = new std::mutex();
2731
static auto *map =
2832
new std::unordered_map<std::string, stim_dem_decoder_creator>();
2933
return {*mutex, *map};
3034
}
3135

3236
} // namespace
3337

38+
dem_default_values dem_defaults_for_missing_keys(
39+
const std::function<bool(const std::string &)> &contains_user_key,
40+
const detector_error_model &dem) {
41+
dem_default_values out;
42+
if (!contains_user_key("O"))
43+
out.O = &dem.observables_flips_matrix;
44+
if (!contains_user_key("error_rate_vec"))
45+
out.error_rate_vec = &dem.error_rates;
46+
return out;
47+
}
48+
3449
void register_stim_dem_decoder_creator(const std::string &name,
3550
stim_dem_decoder_creator creator) {
3651
auto reg = get_stim_dem_registry();
37-
std::lock_guard<std::recursive_mutex> lock(reg.mutex);
52+
std::lock_guard<std::mutex> lock(reg.mutex);
3853
reg.map[name] = std::move(creator);
3954
}
4055

4156
void unregister_stim_dem_decoder_creator(const std::string &name) {
4257
auto reg = get_stim_dem_registry();
43-
std::lock_guard<std::recursive_mutex> lock(reg.mutex);
58+
std::lock_guard<std::mutex> lock(reg.mutex);
4459
reg.map.erase(name);
4560
}
4661

4762
std::unique_ptr<decoder>
4863
get_decoder_from_stim_dem(const std::string &name,
4964
const std::string &stim_dem_text,
50-
const cudaqx::heterogeneous_map options) {
65+
const cudaqx::heterogeneous_map &options) {
5166
stim_dem_decoder_creator creator;
5267
{
5368
auto reg = get_stim_dem_registry();
54-
std::lock_guard<std::recursive_mutex> lock(reg.mutex);
69+
std::lock_guard<std::mutex> lock(reg.mutex);
5570
auto iter = reg.map.find(name);
5671
if (iter != reg.map.end())
5772
creator = iter->second;
@@ -68,10 +83,13 @@ get_decoder_from_stim_dem(const std::string &name,
6883
auto dem = dem_from_stim_text(stim_dem_text);
6984

7085
cudaqx::heterogeneous_map merged = options;
71-
if (!merged.contains("O"))
72-
merged.insert("O", dem.observables_flips_matrix);
73-
if (!merged.contains("error_rate_vec"))
74-
merged.insert("error_rate_vec", dem.error_rates);
86+
// Keep in sync with the Python binding in py_decoder.cpp.
87+
auto defaults = dem_defaults_for_missing_keys(
88+
[&](const std::string &key) { return merged.contains(key); }, dem);
89+
if (defaults.O)
90+
merged.insert("O", *defaults.O);
91+
if (defaults.error_rate_vec)
92+
merged.insert("error_rate_vec", *defaults.error_rate_vec);
7593

7694
return decoder::get(name, dem.detector_error_matrix, merged);
7795
}

libs/qec/lib/detector_error_model.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@
1515
namespace cudaq::qec {
1616

1717
detector_error_model dem_from_stim_text(const std::string &dem_text) {
18-
stim::DetectorErrorModel dem(dem_text);
18+
auto dem = [&dem_text]() {
19+
try {
20+
return stim::DetectorErrorModel(dem_text);
21+
} catch (const std::exception &e) {
22+
throw std::runtime_error(std::string("Stim DEM parse failed: ") +
23+
e.what());
24+
}
25+
}();
1926
const std::size_t num_detectors =
2027
static_cast<std::size_t>(dem.count_detectors());
2128
const std::size_t num_observables =
@@ -46,6 +53,15 @@ detector_error_model dem_from_stim_text(const std::string &dem_text) {
4653
dets.push_back(static_cast<std::size_t>(target.val()));
4754
} else if (target.is_observable_id()) {
4855
obs.push_back(static_cast<std::size_t>(target.val()));
56+
} else {
57+
// Forward-compat tripwire; unreachable today (stim's three
58+
// DemTarget categories are exhaustive -- pinned by
59+
// StimDemTargetCategoriesAreExhaustive).
60+
throw std::runtime_error(
61+
"Stim DEM error instruction (index " +
62+
std::to_string(instruction_index) +
63+
") contains an unsupported target kind; only D* (detector) and "
64+
"L* (observable) targets are supported by the fallback parser");
4965
}
5066
}
5167
detector_hits.push_back(std::move(dets));
@@ -55,6 +71,11 @@ detector_error_model dem_from_stim_text(const std::string &dem_text) {
5571
});
5672

5773
const std::size_t num_errors = rates.size();
74+
// Reject zero-column H at the boundary instead of letting decoders
75+
// crash with block_size == 0.
76+
if (num_errors == 0)
77+
throw std::runtime_error(
78+
"Stim DEM contains no error mechanisms after flattening");
5879
detector_error_model result;
5980
result.detector_error_matrix =
6081
cudaqx::tensor<uint8_t>({num_detectors, num_errors});

libs/qec/python/bindings/py_decoder.cpp

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,22 @@ makeBatchDecoderResult(const std::vector<decoder_result> &results) {
288288
};
289289
}
290290

291+
// Wrap a borrowed cudaqx buffer in a NumPy array and force a Python-side copy,
292+
// so the returned object owns its data.
293+
nb::object toPyArray(const cudaqx::tensor<uint8_t> &t) {
294+
size_t shape[2] = {t.shape()[0], t.shape()[1]};
295+
auto arr = nb::ndarray<nb::numpy, uint8_t>(const_cast<uint8_t *>(t.data()), 2,
296+
shape, nb::none());
297+
return nb::cast(arr).attr("copy")();
298+
}
299+
300+
nb::object toPyArray(const std::vector<double> &v) {
301+
size_t shape[1] = {v.size()};
302+
auto arr = nb::ndarray<nb::numpy, double>(const_cast<double *>(v.data()), 1,
303+
shape, nb::none());
304+
return nb::cast(arr).attr("copy")();
305+
}
306+
291307
} // namespace
292308

293309
void bindDecoder(nb::module_ &mod) {
@@ -713,26 +729,16 @@ void bindDecoder(nb::module_ &mod) {
713729
if (PyDecoderRegistry::contains(name)) {
714730
auto dem = dem_from_stim_text(dem_text);
715731

716-
if (!options.contains("O")) {
717-
const auto &O = dem.observables_flips_matrix;
718-
size_t shape[2] = {O.shape()[0], O.shape()[1]};
719-
auto O_arr = nb::ndarray<nb::numpy, uint8_t>(
720-
const_cast<uint8_t *>(O.data()), 2, shape, nb::none());
721-
options["O"] = nb::cast(O_arr).attr("copy")();
722-
}
723-
if (!options.contains("error_rate_vec")) {
724-
const auto &rates = dem.error_rates;
725-
size_t rates_shape[1] = {rates.size()};
726-
auto rates_arr = nb::ndarray<nb::numpy, double>(
727-
const_cast<double *>(rates.data()), 1, rates_shape, nb::none());
728-
options["error_rate_vec"] = nb::cast(rates_arr).attr("copy")();
729-
}
732+
// Keep in sync with the C++ fallback in decoder_stim_dem.cpp.
733+
auto defaults = dem_defaults_for_missing_keys(
734+
[&](const std::string &key) { return options.contains(key); },
735+
dem);
736+
if (defaults.O)
737+
options["O"] = toPyArray(*defaults.O);
738+
if (defaults.error_rate_vec)
739+
options["error_rate_vec"] = toPyArray(*defaults.error_rate_vec);
730740

731-
const auto &H = dem.detector_error_matrix;
732-
size_t H_shape[2] = {H.shape()[0], H.shape()[1]};
733-
auto H_arr = nb::ndarray<nb::numpy, uint8_t>(
734-
const_cast<uint8_t *>(H.data()), 2, H_shape, nb::none());
735-
nb::object H_obj = nb::cast(H_arr).attr("copy")();
741+
nb::object H_obj = toPyArray(dem.detector_error_matrix);
736742
return PyDecoderRegistry::get_decoder(
737743
name, nb::cast<nb::ndarray<nb::numpy, uint8_t>>(H_obj), options);
738744
}

libs/qec/python/cudaq_qec/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def checked_decode_batch(self, *args, **kwargs):
8484
get_code = qecrt.get_code
8585
get_available_codes = qecrt.get_available_codes
8686
get_decoder = qecrt.get_decoder
87+
get_decoder_from_stim_dem = qecrt.get_decoder_from_stim_dem
8788
DecoderResult = qecrt.DecoderResult
8889
BatchDecoderResult = qecrt.BatchDecoderResult
8990
DetectorErrorModel = qecrt.DetectorErrorModel

libs/qec/python/tests/test_decoder.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,13 +528,25 @@ def test_get_decoder_from_stim_dem():
528528

529529

530530
def test_get_decoder_from_stim_dem_rejects_malformed_text():
531-
with pytest.raises(Exception):
531+
with pytest.raises(RuntimeError):
532532
qec.get_decoder_from_stim_dem("single_error_lut", "not a valid DEM")
533533

534534

535535
def test_get_decoder_from_stim_dem_rejects_unknown_decoder():
536+
with pytest.raises(RuntimeError, match="__no_such_decoder__"):
537+
qec.get_decoder_from_stim_dem("__no_such_decoder__",
538+
"error(0.1) D0 L0\n")
539+
540+
541+
def test_get_decoder_from_stim_dem_user_O_wins_over_dem_derived():
542+
# Wrong-shape user O trips PyMatching's validation; silent overwrite
543+
# by the DEM-derived O would suppress the throw.
544+
dem_text = ("error(0.1) D0 L0\n"
545+
"error(0.1) D1 L0\n"
546+
"error(0.05) D0 D1\n")
547+
bad_O = np.zeros((1, 4), dtype=np.uint8)
536548
with pytest.raises(RuntimeError):
537-
qec.get_decoder_from_stim_dem("__no_such_decoder__", "error(0.1) D0\n")
549+
qec.get_decoder_from_stim_dem("pymatching", dem_text, O=bad_O)
538550

539551

540552
if __name__ == "__main__":

libs/qec/unittests/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ find_package(CUDAToolkit REQUIRED)
3535
add_compile_options(-Wno-attributes)
3636

3737
add_executable(test_decoders test_decoders.cpp decoders/sample_decoder.cpp)
38-
target_link_libraries(test_decoders PRIVATE GTest::gtest_main cudaq-qec cudaq-qec-realtime-decoding cudaq::cudaq)
38+
# Direct libstim link for StimDemTargetCategoriesAreExhaustive;
39+
# cudaq-qec hides stim symbols via --exclude-libs.
40+
target_link_libraries(test_decoders PRIVATE GTest::gtest_main cudaq-qec cudaq-qec-realtime-decoding cudaq::cudaq libstim)
3941
add_dependencies(CUDAQXQECUnitTests test_decoders)
4042
gtest_discover_tests(test_decoders)
4143

0 commit comments

Comments
 (0)