Skip to content

Commit 6603369

Browse files
address pr comments
Signed-off-by: vedika-saravanan <vsaravanan@nvidia.com>
1 parent bccd851 commit 6603369

5 files changed

Lines changed: 16 additions & 194 deletions

File tree

libs/qec/include/cudaq/qec/decoder.h

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -445,37 +445,21 @@ dem_default_values dem_defaults_for_missing_keys(
445445
const std::function<bool(const std::string &)> &contains_user_key,
446446
const detector_error_model &dem);
447447

448-
/// @brief Creator function for a decoder constructed from a Stim DEM string.
449-
using stim_dem_decoder_creator = std::function<std::unique_ptr<decoder>(
450-
const std::string &, const cudaqx::heterogeneous_map &)>;
451-
452-
/// @brief Register a Stim-DEM-string creator for the named decoder.
453-
/// @param name Decoder name; same name used by `get_decoder_from_stim_dem`.
454-
/// @param creator Builds a decoder from the raw DEM string + options. Takes
455-
/// precedence over the H/O fallback; must not re-enter the registry (the
456-
/// factory copies it out before invoking).
457-
/// @see get_decoder_from_stim_dem
458-
void register_stim_dem_decoder_creator(const std::string &name,
459-
stim_dem_decoder_creator creator);
460-
461-
/// @brief Unregister a previously registered Stim-DEM-string creator. No-op if
462-
/// \p name has no registered creator.
463-
/// @see register_stim_dem_decoder_creator
464-
void unregister_stim_dem_decoder_creator(const std::string &name);
465-
466448
/// @brief Construct a decoder by name from a Stim detector error model text.
467449
///
468-
/// When a Stim-DEM creator is registered for \p name it is used directly.
469-
/// Otherwise the DEM is parsed and forwarded to the existing H-based path
470-
/// after injecting two derived entries into \p options if they are not
471-
/// already present:
450+
/// Thin wrapper over \c dem_from_stim_text: parses the DEM and forwards to
451+
/// the existing H-based \c decoder::get after injecting two derived entries
452+
/// into \p options if they are not already present:
472453
/// - `"O"` : `cudaqx::tensor<uint8_t>` observables_flips_matrix
473454
/// - `"error_rate_vec"` : `std::vector<double>` per-error probabilities
474455
/// User-supplied values for either key win over the DEM-derived ones.
475456
///
476-
/// @note Decoders that need full DEM metadata (e.g. Chromobius) must
477-
/// register a Stim-DEM creator; the fallback only extracts H/O/rates.
478-
/// @see register_stim_dem_decoder_creator
457+
/// @note Lossy: detector annotations, decomposition separators, and
458+
/// `error_ids` are dropped. Sufficient for matching-style / H-based
459+
/// decoders (LUT, NV, sliding_window, TRT, PyMatching). Decoders that
460+
/// need full DEM metadata (e.g. Chromobius detector color/basis) require
461+
/// the planned \c detector_coords extension on
462+
/// \c cudaq::qec::detector_error_model; tracked as a follow-up.
479463
std::unique_ptr<decoder>
480464
get_decoder_from_stim_dem(const std::string &name,
481465
const std::string &stim_dem_text,

libs/qec/include/cudaq/qec/detector_error_model.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ struct detector_error_model {
7575
/// @note Lossy: only detector/observable flips and error probabilities
7676
/// are extracted. Annotations (`detector`, `logical_observable`),
7777
/// suggested-decomposition separators, and \p error_ids are dropped.
78-
/// Decoders that need the full DEM (e.g. Chromobius) must consume the
79-
/// raw string via `register_stim_dem_decoder_creator`.
78+
/// Decoders that need full DEM metadata (e.g. Chromobius detector
79+
/// color/basis) require the planned \p detector_coords extension on
80+
/// \p detector_error_model; tracked as a follow-up.
8081
detector_error_model dem_from_stim_text(const std::string &dem_text);
8182

8283
} // namespace cudaq::qec

libs/qec/lib/decoder_stim_dem.cpp

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,18 @@
11
/*******************************************************************************
22
* Copyright (c) 2024 - 2026 NVIDIA Corporation & Affiliates. *
33
* All rights reserved. *
4-
* *
54
* This source code and the accompanying materials are made available under *
65
* the terms of the Apache License 2.0 which accompanies this distribution. *
76
******************************************************************************/
87

98
#include "cudaq/qec/decoder.h"
109
#include "cudaq/qec/detector_error_model.h"
1110

12-
#include <mutex>
1311
#include <stdexcept>
1412
#include <string>
15-
#include <unordered_map>
1613

1714
namespace cudaq::qec {
1815

19-
namespace {
20-
21-
// std::mutex is enough: the factory copies the creator out before
22-
// invoking, so creators cannot re-enter the registry.
23-
struct stim_dem_registry {
24-
std::mutex &mutex;
25-
std::unordered_map<std::string, stim_dem_decoder_creator> &map;
26-
};
27-
stim_dem_registry get_stim_dem_registry() {
28-
// Heap-allocated to outlive static destructors (plugin dlclose unregister
29-
// path); matches the cudaqx extension_point pattern. See extension_point.h.
30-
static std::mutex *mutex = new std::mutex();
31-
static auto *map =
32-
new std::unordered_map<std::string, stim_dem_decoder_creator>();
33-
return {*mutex, *map};
34-
}
35-
36-
} // namespace
37-
3816
dem_default_values dem_defaults_for_missing_keys(
3917
const std::function<bool(const std::string &)> &contains_user_key,
4018
const detector_error_model &dem) {
@@ -46,34 +24,10 @@ dem_default_values dem_defaults_for_missing_keys(
4624
return out;
4725
}
4826

49-
void register_stim_dem_decoder_creator(const std::string &name,
50-
stim_dem_decoder_creator creator) {
51-
auto reg = get_stim_dem_registry();
52-
std::lock_guard<std::mutex> lock(reg.mutex);
53-
reg.map[name] = std::move(creator);
54-
}
55-
56-
void unregister_stim_dem_decoder_creator(const std::string &name) {
57-
auto reg = get_stim_dem_registry();
58-
std::lock_guard<std::mutex> lock(reg.mutex);
59-
reg.map.erase(name);
60-
}
61-
6227
std::unique_ptr<decoder>
6328
get_decoder_from_stim_dem(const std::string &name,
6429
const std::string &stim_dem_text,
6530
const cudaqx::heterogeneous_map &options) {
66-
stim_dem_decoder_creator creator;
67-
{
68-
auto reg = get_stim_dem_registry();
69-
std::lock_guard<std::mutex> lock(reg.mutex);
70-
auto iter = reg.map.find(name);
71-
if (iter != reg.map.end())
72-
creator = iter->second;
73-
}
74-
if (creator)
75-
return creator(stim_dem_text, options);
76-
7731
if (!decoder::is_registered(name))
7832
throw std::runtime_error(
7933
"get_decoder_from_stim_dem: decoder \"" + name +

libs/qec/python/bindings/py_decoder.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -747,13 +747,12 @@ void bindDecoder(nb::module_ &mod) {
747747
hetMapFromKwargs(options));
748748
},
749749
"Construct a decoder by name from a Stim detector error model string. "
750-
"Observables and per-error rates from the DEM are injected into options "
751-
"under keys \"O\" and \"error_rate_vec\" when no registered Stim-DEM "
752-
"creator is found. User-supplied values for either key win over the "
750+
"Thin wrapper over dem_from_stim_text: observables and per-error rates "
751+
"from the DEM are injected into options under keys \"O\" and "
752+
"\"error_rate_vec\". User-supplied values for either key win over the "
753753
"DEM-derived ones. Python decoders registered via @qec.decoder receive "
754754
"the parsed H and O as numpy.ndarray and error_rate_vec as a 1-D "
755-
"numpy.ndarray of float64; to register a native DEM consumer, use the "
756-
"C++ register_stim_dem_decoder_creator API.");
755+
"numpy.ndarray of float64.");
757756

758757
qecmod.def(
759758
"get_sorted_pcm_column_indices",

libs/qec/unittests/test_decoders.cpp

Lines changed: 0 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -764,33 +764,6 @@ error(0.05) D0 D1
764764
}
765765
}
766766

767-
TEST(StimDemDecoderFactory, RegisteredCreatorIsUsed) {
768-
// static: the registry outlives this test, so the lambda must not
769-
// capture a stack reference.
770-
static bool registered_creator_was_called = false;
771-
registered_creator_was_called = false;
772-
// RAII guard: restore the registry slot on any exit path.
773-
struct CreatorGuard {
774-
const char *name;
775-
~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); }
776-
} guard{"__stim_dem_test_decoder__"};
777-
cudaq::qec::register_stim_dem_decoder_creator(
778-
"__stim_dem_test_decoder__",
779-
[](const std::string &dem_text, const cudaqx::heterogeneous_map &)
780-
-> std::unique_ptr<cudaq::qec::decoder> {
781-
registered_creator_was_called = true;
782-
EXPECT_EQ(dem_text, "passthrough");
783-
cudaqx::tensor<uint8_t> H({2u, 2u});
784-
cudaqx::heterogeneous_map empty;
785-
return cudaq::qec::decoder::get("single_error_lut", H, empty);
786-
});
787-
788-
auto d = cudaq::qec::get_decoder_from_stim_dem("__stim_dem_test_decoder__",
789-
"passthrough");
790-
EXPECT_TRUE(registered_creator_was_called);
791-
ASSERT_NE(d, nullptr);
792-
}
793-
794767
TEST(StimDemDecoderFactory, RepeatedDetectorOrObservableTargetsXorFold) {
795768
const std::string dem_text = R"(error(0.1) D0 D0
796769
error(0.1) L0 L0
@@ -854,34 +827,6 @@ TEST(StimDemDecoderFactory, StimDemTargetCategoriesAreExhaustive) {
854827
}
855828
}
856829

857-
TEST(StimDemDecoderFactory, RegisteredCreatorTakesPrecedenceOverFallback) {
858-
// Real decoder name on purpose: pins creator-over-fallback for an
859-
// existing decoder (a sentinel name would just retest
860-
// RegisteredCreatorIsUsed). Mutates the real "single_error_lut" slot;
861-
// safe only because gtest runs tests serially in a binary. If this
862-
// suite is ever parallelized, register against a sentinel name instead.
863-
static bool creator_was_called = false;
864-
creator_was_called = false;
865-
struct CreatorGuard {
866-
const char *name;
867-
~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); }
868-
} guard{"single_error_lut"};
869-
cudaq::qec::register_stim_dem_decoder_creator(
870-
"single_error_lut",
871-
[](const std::string &, const cudaqx::heterogeneous_map &)
872-
-> std::unique_ptr<cudaq::qec::decoder> {
873-
creator_was_called = true;
874-
cudaqx::tensor<uint8_t> H({2u, 2u});
875-
cudaqx::heterogeneous_map empty;
876-
return cudaq::qec::decoder::get("single_error_lut", H, empty);
877-
});
878-
879-
const std::string dem_text = "error(0.1) D0 L0\n";
880-
auto d = cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text);
881-
EXPECT_TRUE(creator_was_called);
882-
ASSERT_NE(d, nullptr);
883-
}
884-
885830
TEST(StimDemDecoderFactory, UserOptionsAreNotOverwritten) {
886831
const std::string dem_text = R"(error(0.1) D0 L0
887832
error(0.1) D1 L0
@@ -893,64 +838,3 @@ error(0.05) D0 D1
893838
cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text, opts),
894839
std::runtime_error);
895840
}
896-
897-
TEST(StimDemDecoderFactory, UserSuppliedObservablesAreNotOverwritten) {
898-
// Symmetric with UserOptionsAreNotOverwritten but for "O", via an
899-
// echo creator (decoder-validation-independent).
900-
static std::vector<std::size_t> observed_O_shape;
901-
observed_O_shape.clear();
902-
struct CreatorGuard {
903-
const char *name;
904-
~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); }
905-
} guard{"__stim_dem_echo_O__"};
906-
cudaq::qec::register_stim_dem_decoder_creator(
907-
"__stim_dem_echo_O__",
908-
[](const std::string &, const cudaqx::heterogeneous_map &opts)
909-
-> std::unique_ptr<cudaq::qec::decoder> {
910-
if (opts.contains("O")) {
911-
auto O = opts.get<cudaqx::tensor<uint8_t>>("O");
912-
observed_O_shape = O.shape();
913-
}
914-
cudaqx::tensor<uint8_t> H({2u, 2u});
915-
cudaqx::heterogeneous_map empty;
916-
return cudaq::qec::decoder::get("single_error_lut", H, empty);
917-
});
918-
919-
// Distinctive shape; a match proves the user's O reached the creator.
920-
cudaqx::tensor<uint8_t> user_O({7u, 11u});
921-
cudaqx::heterogeneous_map opts;
922-
opts.insert("O", user_O);
923-
auto d = cudaq::qec::get_decoder_from_stim_dem("__stim_dem_echo_O__",
924-
"error(0.1) D0 L0\n", opts);
925-
ASSERT_NE(d, nullptr);
926-
ASSERT_EQ(observed_O_shape.size(), 2u);
927-
EXPECT_EQ(observed_O_shape[0], 7u);
928-
EXPECT_EQ(observed_O_shape[1], 11u);
929-
}
930-
931-
TEST(StimDemDecoderFactory, RegisteredCreatorReceivesUserOptionsVerbatim) {
932-
static std::vector<double> observed_rates;
933-
observed_rates.clear();
934-
struct CreatorGuard {
935-
const char *name;
936-
~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); }
937-
} guard{"__stim_dem_echo__"};
938-
cudaq::qec::register_stim_dem_decoder_creator(
939-
"__stim_dem_echo__",
940-
[](const std::string &, const cudaqx::heterogeneous_map &opts)
941-
-> std::unique_ptr<cudaq::qec::decoder> {
942-
if (opts.contains("error_rate_vec"))
943-
observed_rates = opts.get<std::vector<double>>("error_rate_vec");
944-
cudaqx::tensor<uint8_t> H({2u, 2u});
945-
cudaqx::heterogeneous_map empty;
946-
return cudaq::qec::decoder::get("single_error_lut", H, empty);
947-
});
948-
949-
const std::vector<double> user_rates = {0.42, 0.13, 0.07};
950-
cudaqx::heterogeneous_map opts;
951-
opts.insert("error_rate_vec", user_rates);
952-
auto d = cudaq::qec::get_decoder_from_stim_dem("__stim_dem_echo__",
953-
"error(0.5) D0\n", opts);
954-
ASSERT_NE(d, nullptr);
955-
EXPECT_EQ(observed_rates, user_rates);
956-
}

0 commit comments

Comments
 (0)