Skip to content

Commit 795f5fd

Browse files
Support Stim DEM strings in get_decoder (#571)
## Description Adds unified decoder initialization from either a parity-check matrix or raw Stim detector error model text. This supports both existing PCM-based decoders and DEM-native decoders needed by #546. PCM-based decoders can parse DEM text into `H`, `O`, and `error_rate_vec` defaults, while DEM-native decoders can consume the raw DEM text without going through a lossy matrix conversion. ### API (`cudaq/qec/decoder.h`) - Adds `decoder_init = std::variant<sparse_binary_matrix, std::string>` as the decoder registry construction input. - Keeps existing PCM construction paths: - `get_decoder(name, H, options)` - `decoder::get(name, H, options)` - Adds DEM string construction through the same entry points: - `get_decoder(name, dem_text, options)` - `decoder::get(name, dem_text, options)` - Adds `dem_from_stim_text(dem_text)` to parse Stim DEM text into `detector_error_model`. - Adds helper routing for PCM-based decoders so DEM-derived `O` and `error_rate_vec` are supplied as defaults when not explicitly provided by the user. The DEM-to-detector_error_model parser is lossy: it extracts detector flips, observable flips, and per-error probabilities, but drops detector coordinates and separator-encoded correlation structure. DEM-native decoders should consume the raw string alternative in decoder_init. Python: `cudaq_qec.get_decoder(...)` now accepts Stim DEM strings. Python-registered decoders currently receive parsed `H` plus DEM-derived `O` / `error_rate_vec` defaults, not raw DEM text. ### Dependency / build QEC now provides its own Stim dependency via `FetchContent` when `libstim` is not already available, so standalone QEC builds do not depend on parent CUDA-Q interim build artifacts. ### Tests C++ gtests and Python tests cover: - DEM string construction through `get_decoder(...)` - PCM construction still working - user-provided options overriding DEM-derived defaults - DEM parse edge cases - DEMs without observables - `stim::DemTarget` category assumptions ### Out of scope / follow-ups Chromobius plugin integration itself; detector-coordinate storage on `detector_error_model`; optional PyMatching-specific improvements; user-facing Sphinx/RST docs. ## Runtime / performance impact N/A ## Self-review checklist Please confirm each item before requesting review. Check `[x]` or strike through and explain. ### Before requesting review - [x] I reviewed my own full diff in GitHub or my editor. - [x] PR is in Draft if it is not yet ready for review. - [x] Temporary / debugging changes have been removed. - [x] Local test logs reviewed; no unexplained warnings or errors. - [x] CI logs reviewed; no unexplained warnings or errors. - [x] Full CI has been run. ### Scope and size - [x] PR is under ~1000 lines, or an exception is justified in the description. - [x] Refactoring-only changes are isolated in their own PR(s). - [x] No existing tests were disabled or modified just to make this PR pass (if so, an issue has been raised). ### Tests - [x] New functionality has new tests. - [x] Tests fail if the new functionality is broken (including crashes), not just when it is missing. - [x] Negative tests added where exceptions are expected. - [x] Truth data added where simple `EXPECT_*` / `assert` checks are insufficient for algorithmic correctness. - [x] CI runtime impact considered; team notified if significant. ### Documentation - [x] Public-facing APIs have Doxygen docs. - [x] User-visible behavior changes have public docs, or a follow-up is tracked. - [x] User-facing docs for new features are in a **separate PR** held until release (the docs site publishes immediately on merge to the default branch, so feature docs must not land before the feature ships). ### Code style - [x] Naming follows the existing convention (`snake_case` vs `camelCase`) for the area being modified. ### Dependencies - [x] No new third-party dependencies, **or** the team has been notified and OSRB tickets filed. --------- Signed-off-by: vedika-saravanan <vsaravanan@nvidia.com>
1 parent ccb62c4 commit 795f5fd

16 files changed

Lines changed: 535 additions & 94 deletions

File tree

libs/qec/include/cudaq/qec/decoder.h

Lines changed: 124 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,16 @@
1212
#include "cuda-qx/core/heterogeneous_map.h"
1313
#include "cuda-qx/core/tensor.h"
1414
#include "sparse_binary_matrix.h"
15+
#include "cudaq/qec/detector_error_model.h"
16+
#include <algorithm>
17+
#include <functional>
1518
#include <future>
19+
#include <memory>
1620
#include <optional>
21+
#include <string>
22+
#include <string_view>
23+
#include <tuple>
24+
#include <variant>
1725
#include <vector>
1826

1927
namespace cudaq::qec {
@@ -24,6 +32,10 @@ using float_t = CUDAQX_QEC_FLOAT_TYPE;
2432
using float_t = double;
2533
#endif
2634

35+
/// Decoder construction input: either a parity-check matrix or raw Stim DEM
36+
/// text.
37+
using decoder_init = std::variant<sparse_binary_matrix, std::string>;
38+
2739
/// @brief Validates that all keys in a heterogeneous map are found in a list of
2840
/// acceptable types
2941
/// @param config The heterogeneous map to validate
@@ -122,8 +134,7 @@ class async_decoder_result {
122134
/// arbitrary constructor parameters that can be unique to each specific
123135
/// decoder.
124136
class decoder
125-
: public cudaqx::extension_point<decoder,
126-
const cudaq::qec::sparse_binary_matrix &,
137+
: public cudaqx::extension_point<decoder, const decoder_init &,
127138
const cudaqx::heterogeneous_map &> {
128139
private:
129140
struct rt_impl;
@@ -143,16 +154,16 @@ class decoder
143154
/// @brief Decode a single syndrome
144155
/// @param syndrome A vector of syndrome measurements where the floating point
145156
/// value is the probability that the syndrome measurement is a |1>. The
146-
/// length of the syndrome vector should be equal to \p syndrome_size.
147-
/// @returns Vector of length \p block_size with soft probabilities of errors
157+
/// length of the syndrome vector should be equal to `syndrome_size`.
158+
/// @returns Vector of length `block_size` with soft probabilities of errors
148159
/// in each index.
149160
virtual decoder_result decode(const std::vector<float_t> &syndrome) = 0;
150161

151162
/// @brief Decode a single syndrome
152163
/// @param syndrome An order-1 tensor of syndrome measurements where a 1 bit
153164
/// represents that the syndrome measurement is a |1>. The
154-
/// length of the syndrome vector should be equal to \p syndrome_size.
155-
/// @returns Vector of length \p block_size of errors in each index.
165+
/// length of the syndrome vector should be equal to `syndrome_size`.
166+
/// @returns Vector of length `block_size` of errors in each index.
156167
virtual decoder_result decode(const cudaqx::tensor<uint8_t> &syndrome);
157168

158169
/// @brief Decode a single syndrome
@@ -172,11 +183,49 @@ class decoder
172183
virtual std::vector<decoder_result>
173184
decode_batch(const std::vector<std::vector<float_t>> &syndrome);
174185

175-
/// @brief This `get` overload supports default values.
186+
/// @brief Construct a registered decoder by name.
187+
/// @param name The registered decoder name.
188+
/// @param init A parity-check matrix or raw Stim DEM string.
189+
/// @param param_map Optional decoder-specific parameters.
176190
static std::unique_ptr<decoder>
177-
get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
191+
get(const std::string &name, const decoder_init &init,
178192
const cudaqx::heterogeneous_map &param_map = cudaqx::heterogeneous_map());
179193

194+
static std::unique_ptr<decoder>
195+
get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
196+
const cudaqx::heterogeneous_map &param_map =
197+
cudaqx::heterogeneous_map()) {
198+
return get(name, decoder_init{H}, param_map);
199+
}
200+
201+
static std::unique_ptr<decoder>
202+
get(const std::string &name, const cudaqx::tensor<uint8_t> &H,
203+
const cudaqx::heterogeneous_map &param_map =
204+
cudaqx::heterogeneous_map()) {
205+
return get(name, cudaq::qec::sparse_binary_matrix(H), param_map);
206+
}
207+
208+
static std::unique_ptr<decoder>
209+
get(const std::string &name, const std::string &stim_dem_text,
210+
const cudaqx::heterogeneous_map &param_map =
211+
cudaqx::heterogeneous_map()) {
212+
return get(name, decoder_init{stim_dem_text}, param_map);
213+
}
214+
215+
static std::unique_ptr<decoder>
216+
get(const std::string &name, const char *stim_dem_text,
217+
const cudaqx::heterogeneous_map &param_map =
218+
cudaqx::heterogeneous_map()) {
219+
return get(name, decoder_init{std::string{stim_dem_text}}, param_map);
220+
}
221+
222+
static std::unique_ptr<decoder>
223+
get(const std::string &name, std::string_view stim_dem_text,
224+
const cudaqx::heterogeneous_map &param_map =
225+
cudaqx::heterogeneous_map()) {
226+
return get(name, decoder_init{std::string{stim_dem_text}}, param_map);
227+
}
228+
180229
std::size_t get_block_size() { return block_size; }
181230
std::size_t get_syndrome_size() { return syndrome_size; }
182231

@@ -435,6 +484,72 @@ inline void convert_vec_hard_to_soft(const std::vector<std::vector<t_hard>> &in,
435484
}
436485

437486
std::unique_ptr<decoder>
438-
get_decoder(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
487+
get_decoder(const std::string &name, const decoder_init &init,
439488
const cudaqx::heterogeneous_map options = {});
489+
490+
inline std::unique_ptr<decoder>
491+
get_decoder(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
492+
const cudaqx::heterogeneous_map options = {}) {
493+
return get_decoder(name, decoder_init{H}, options);
494+
}
495+
496+
inline std::unique_ptr<decoder>
497+
get_decoder(const std::string &name, const cudaqx::tensor<uint8_t> &H,
498+
const cudaqx::heterogeneous_map options = {}) {
499+
return get_decoder(name, cudaq::qec::sparse_binary_matrix(H), options);
500+
}
501+
502+
inline std::unique_ptr<decoder>
503+
get_decoder(const std::string &name, const std::string &stim_dem_text,
504+
const cudaqx::heterogeneous_map options = {}) {
505+
return get_decoder(name, decoder_init{stim_dem_text}, options);
506+
}
507+
508+
inline std::unique_ptr<decoder>
509+
get_decoder(const std::string &name, const char *stim_dem_text,
510+
const cudaqx::heterogeneous_map options = {}) {
511+
return get_decoder(name, decoder_init{std::string{stim_dem_text}}, options);
512+
}
513+
514+
inline std::unique_ptr<decoder>
515+
get_decoder(const std::string &name, std::string_view stim_dem_text,
516+
const cudaqx::heterogeneous_map options = {}) {
517+
return get_decoder(name, decoder_init{std::string{stim_dem_text}}, options);
518+
}
519+
520+
namespace details {
521+
// Declared here because `make_pcm_decoder` is a header-defined template.
522+
/// DEM-derived defaults; pointers alias into the source `dem`.
523+
struct dem_default_values {
524+
const cudaqx::tensor<uint8_t> *O = nullptr;
525+
const std::vector<double> *error_rate_vec = nullptr;
526+
};
527+
528+
/// Return DEM defaults for keys not already supplied by the user.
529+
dem_default_values dem_defaults_for_missing_keys(
530+
const std::function<bool(const std::string &)> &contains_user_key,
531+
const detector_error_model &dem);
532+
} // namespace details
533+
534+
/// If `init` holds DEM text, parse it and inject `"O"` / `"error_rate_vec"`
535+
/// defaults when absent.
536+
template <typename DecoderT>
537+
std::unique_ptr<decoder>
538+
make_pcm_decoder(const decoder_init &init,
539+
const cudaqx::heterogeneous_map &params) {
540+
if (const auto *H = std::get_if<cudaq::qec::sparse_binary_matrix>(&init))
541+
return std::make_unique<DecoderT>(*H, params);
542+
543+
const auto dem = dem_from_stim_text(std::get<std::string>(init));
544+
cudaqx::heterogeneous_map merged = params;
545+
const auto defaults = details::dem_defaults_for_missing_keys(
546+
[&](const std::string &key) { return merged.contains(key); }, dem);
547+
if (defaults.O)
548+
merged.insert("O", *defaults.O);
549+
if (defaults.error_rate_vec)
550+
merged.insert("error_rate_vec", *defaults.error_rate_vec);
551+
return std::make_unique<DecoderT>(
552+
cudaq::qec::sparse_binary_matrix(dem.detector_error_matrix), merged);
553+
}
554+
440555
} // namespace cudaq::qec

libs/qec/include/cudaq/qec/detector_error_model.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
#pragma once
99

1010
#include "cuda-qx/core/tensor.h"
11+
#include <cstddef>
12+
#include <cstdint>
1113
#include <optional>
14+
#include <string>
15+
#include <vector>
1216

1317
namespace cudaq::qec {
1418

@@ -18,9 +22,9 @@ namespace cudaq::qec {
1822
/// decoder to help make predictions about observables flips.
1923
///
2024
/// Shared size parameters among the matrix types.
21-
/// - \p detector_error_matrix: num_detectors x num_error_mechanisms [d, e]
22-
/// - \p error_rates: num_error_mechanisms
23-
/// - \p observables_flips_matrix: num_observables x num_error_mechanisms [k, e]
25+
/// - `detector_error_matrix`: num_detectors x num_error_mechanisms [d, e]
26+
/// - `error_rates`: num_error_mechanisms
27+
/// - `observables_flips_matrix`: num_observables x num_error_mechanisms [k, e]
2428
///
2529
/// @note The C++ API for this class may change in the future. The Python API is
2630
/// more likely to be backwards compatible.
@@ -32,7 +36,7 @@ struct detector_error_model {
3236
cudaqx::tensor<uint8_t> detector_error_matrix;
3337

3438
/// The list of weights has length equal to the number of columns of
35-
/// \p detector_error_matrix, which assigns a likelihood to each error
39+
/// `detector_error_matrix`, which assigns a likelihood to each error
3640
/// mechanism.
3741
std::vector<double> error_rates;
3842

@@ -65,4 +69,8 @@ struct detector_error_model {
6569
void canonicalize_for_rounds(uint32_t num_syndromes_per_round);
6670
};
6771

72+
/// Parse Stim DEM text into detector/observable flip matrices and error rates.
73+
/// This is lossy; DEM-native decoders should consume raw DEM text instead.
74+
detector_error_model dem_from_stim_text(const std::string &dem_text);
75+
6876
} // namespace cudaq::qec

libs/qec/lib/CMakeLists.txt

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
set(LIBRARY_NAME cudaq-qec)
1010

11+
include(FetchContent)
12+
1113
add_compile_options(-Wno-attributes)
1214

1315
find_package(CUDAToolkit REQUIRED)
@@ -44,7 +46,7 @@ set(QEC_SOURCES
4446
pcm_utils.cpp
4547
plugin_loader.cpp
4648
sparse_binary_matrix.cpp
47-
stabilizer_utils.cpp
49+
stabilizer_utils.cpp
4850
decoders/lut.cpp
4951
decoders/sliding_window.cpp
5052
version.cpp
@@ -58,6 +60,28 @@ list(APPEND QEC_SOURCES
5860
# FIXME?: This must be a shared library. Trying to build a static one will fail.
5961
add_library(${LIBRARY_NAME} SHARED ${QEC_SOURCES})
6062

63+
if(NOT TARGET libstim)
64+
FetchContent_Declare(
65+
stim
66+
GIT_REPOSITORY https://github.com/quantumlib/Stim.git
67+
GIT_TAG v1.15.0
68+
EXCLUDE_FROM_ALL
69+
)
70+
FetchContent_MakeAvailable(stim)
71+
endif()
72+
73+
if(NOT TARGET libstim)
74+
message(FATAL_ERROR
75+
"Stim FetchContent did not provide the libstim target.")
76+
endif()
77+
set_target_properties(${LIBRARY_NAME} PROPERTIES
78+
VISIBILITY_INLINES_HIDDEN ON
79+
)
80+
target_link_libraries(${LIBRARY_NAME} PRIVATE libstim)
81+
target_link_options(${LIBRARY_NAME} PRIVATE
82+
$<$<OR:$<CXX_COMPILER_ID:GNU>,$<CXX_COMPILER_ID:Clang>>:-Wl,--exclude-libs,libstim.a>
83+
)
84+
6185
add_subdirectory(decoders/plugins/example)
6286
add_subdirectory(decoders/plugins/pymatching)
6387

libs/qec/lib/decoder.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@
1717
#include <filesystem>
1818
#include <vector>
1919

20-
INSTANTIATE_REGISTRY(cudaq::qec::decoder,
21-
const cudaq::qec::sparse_binary_matrix &)
22-
INSTANTIATE_REGISTRY(cudaq::qec::decoder,
23-
const cudaq::qec::sparse_binary_matrix &,
20+
INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaq::qec::decoder_init &,
2421
const cudaqx::heterogeneous_map &)
2522

2623
// Include decoder implementations AFTER registry instantiation
@@ -131,7 +128,7 @@ decoder::decode_async(const std::vector<float_t> &syndrome) {
131128
}
132129

133130
std::unique_ptr<decoder>
134-
decoder::get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
131+
decoder::get(const std::string &name, const decoder_init &init,
135132
const cudaqx::heterogeneous_map &param_map) {
136133
auto [mutex, registry] = get_registry();
137134
std::lock_guard<std::recursive_mutex> lock(mutex);
@@ -141,9 +138,24 @@ decoder::get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
141138
"invalid decoder requested: " + name +
142139
". Run with CUDAQ_LOG_LEVEL=info (environment variable) to see "
143140
"additional plugin diagnostics at startup.");
144-
return iter->second(H, param_map);
141+
return iter->second(init, param_map);
145142
}
146143

144+
namespace details {
145+
146+
dem_default_values dem_defaults_for_missing_keys(
147+
const std::function<bool(const std::string &)> &contains_user_key,
148+
const detector_error_model &dem) {
149+
dem_default_values out;
150+
if (!contains_user_key("O") && dem.num_observables() > 0)
151+
out.O = &dem.observables_flips_matrix;
152+
if (!contains_user_key("error_rate_vec"))
153+
out.error_rate_vec = &dem.error_rates;
154+
return out;
155+
}
156+
157+
} // namespace details
158+
147159
static uint32_t calculate_num_msyn_per_decode(
148160
const std::vector<std::vector<uint32_t>> &D_sparse) {
149161
uint32_t max_col = 0;
@@ -480,9 +492,9 @@ void decoder::reset_decoder() {
480492
}
481493

482494
std::unique_ptr<decoder> get_decoder(const std::string &name,
483-
const cudaq::qec::sparse_binary_matrix &H,
495+
const decoder_init &init,
484496
const cudaqx::heterogeneous_map options) {
485-
return decoder::get(name, H, options);
497+
return decoder::get(name, init, options);
486498
}
487499

488500
// Constructor function for auto-loading plugins

libs/qec/lib/decoders/lut.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,9 @@ class multi_error_lut : public decoder {
228228

229229
CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
230230
multi_error_lut, static std::unique_ptr<decoder> create(
231-
const cudaq::qec::sparse_binary_matrix &H,
231+
const cudaq::qec::decoder_init &init,
232232
const cudaqx::heterogeneous_map &params) {
233-
return std::make_unique<multi_error_lut>(H, params);
233+
return cudaq::qec::make_pcm_decoder<multi_error_lut>(init, params);
234234
})
235235
};
236236

@@ -246,9 +246,9 @@ class single_error_lut : public multi_error_lut {
246246

247247
CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
248248
single_error_lut, static std::unique_ptr<decoder> create(
249-
const cudaq::qec::sparse_binary_matrix &H,
249+
const cudaq::qec::decoder_init &init,
250250
const cudaqx::heterogeneous_map &params) {
251-
return std::make_unique<single_error_lut>(H, params);
251+
return cudaq::qec::make_pcm_decoder<single_error_lut>(init, params);
252252
})
253253
};
254254

libs/qec/lib/decoders/plugins/example/single_error_lut_example.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,10 @@ class single_error_lut_example : public decoder {
7777

7878
CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
7979
single_error_lut_example, static std::unique_ptr<decoder> create(
80-
const cudaq::qec::sparse_binary_matrix &H,
80+
const cudaq::qec::decoder_init &init,
8181
const cudaqx::heterogeneous_map &params) {
82-
return std::make_unique<single_error_lut_example>(H, params);
82+
return cudaq::qec::make_pcm_decoder<single_error_lut_example>(init,
83+
params);
8384
})
8485
};
8586

libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,9 @@ class pymatching : public decoder {
247247

248248
CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
249249
pymatching, static std::unique_ptr<decoder> create(
250-
const cudaq::qec::sparse_binary_matrix &H,
250+
const cudaq::qec::decoder_init &init,
251251
const cudaqx::heterogeneous_map &params) {
252-
return std::make_unique<pymatching>(H, params);
252+
return cudaq::qec::make_pcm_decoder<pymatching>(init, params);
253253
})
254254
};
255255

libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,9 +432,9 @@ class trt_decoder : public decoder {
432432

433433
CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
434434
trt_decoder, static std::unique_ptr<decoder> create(
435-
const cudaq::qec::sparse_binary_matrix &H,
435+
const cudaq::qec::decoder_init &init,
436436
const cudaqx::heterogeneous_map &params) {
437-
return std::make_unique<trt_decoder>(H, params);
437+
return cudaq::qec::make_pcm_decoder<trt_decoder>(init, params);
438438
})
439439

440440
private:

0 commit comments

Comments
 (0)