Skip to content

Commit f428c85

Browse files
fetches Stim directly with FetchContent when libstim is not already available
Signed-off-by: vedika-saravanan <vsaravanan@nvidia.com>
1 parent f6d7fdf commit f428c85

4 files changed

Lines changed: 84 additions & 60 deletions

File tree

libs/qec/lib/CMakeLists.txt

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
set(LIBRARY_NAME cudaq-qec)
1010

11+
include(FetchContent)
12+
1113
add_compile_options(-Wno-attributes)
1214

1315
find_package(CUDAToolkit REQUIRED)
@@ -59,18 +61,28 @@ list(APPEND QEC_SOURCES
5961
# FIXME?: This must be a shared library. Trying to build a static one will fail.
6062
add_library(${LIBRARY_NAME} SHARED ${QEC_SOURCES})
6163

62-
add_subdirectory(decoders/plugins/example)
63-
add_subdirectory(decoders/plugins/pymatching)
64+
if(NOT TARGET libstim)
65+
FetchContent_Declare(
66+
stim
67+
GIT_REPOSITORY https://github.com/quantumlib/Stim.git
68+
GIT_TAG v1.15.0
69+
EXCLUDE_FROM_ALL
70+
)
71+
FetchContent_MakeAvailable(stim)
72+
endif()
6473

6574
if(NOT TARGET libstim)
6675
message(FATAL_ERROR
67-
"libstim target not available; required by cudaq-qec for Stim DEM parsing.")
76+
"Stim FetchContent did not provide the libstim target.")
6877
endif()
6978
target_link_libraries(${LIBRARY_NAME} PRIVATE libstim)
7079
target_link_options(${LIBRARY_NAME} PRIVATE
7180
$<$<OR:$<CXX_COMPILER_ID:GNU>,$<CXX_COMPILER_ID:Clang>>:-Wl,--exclude-libs,libstim.a>
7281
)
7382

83+
add_subdirectory(decoders/plugins/example)
84+
add_subdirectory(decoders/plugins/pymatching)
85+
7486
# The TRT decoder plugin honors the tri-state `CUDAQ_QEC_BUILD_TRT_DECODER`
7587
# cache variable (AUTO/ON/OFF) declared in the parent CMakeLists.txt. Skip
7688
# descending entirely when the user explicitly opted out; otherwise let the

libs/qec/lib/decoder_stim_dem.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dem_default_values dem_defaults_for_missing_keys(
1717
const std::function<bool(const std::string &)> &contains_user_key,
1818
const detector_error_model &dem) {
1919
dem_default_values out;
20-
if (!contains_user_key("O"))
20+
if (!contains_user_key("O") && dem.num_observables() > 0)
2121
out.O = &dem.observables_flips_matrix;
2222
if (!contains_user_key("error_rate_vec"))
2323
out.error_rate_vec = &dem.error_rates;

libs/qec/python/tests/test_decoder.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -762,12 +762,12 @@ def test_generate_random_pcm_signed_weight_rejects_negative():
762762
seed=1)
763763

764764

765-
def test_get_decoder_from_stim_dem():
765+
def test_get_decoder_accepts_stim_dem_string():
766766
dem_text = ("error(0.1) D0 L0\n"
767767
"error(0.1) D1 L0\n"
768768
"error(0.05) D0 D1\n")
769769

770-
decoder = qec.get_decoder_from_stim_dem("single_error_lut", dem_text)
770+
decoder = qec.get_decoder("single_error_lut", dem_text)
771771
assert decoder is not None
772772
assert decoder.get_syndrome_size() == 2
773773
assert decoder.get_block_size() == 3
@@ -784,19 +784,13 @@ def test_get_decoder_from_stim_dem():
784784
assert list(result.result) == expected, f"syndrome {syndrome}"
785785

786786

787-
def test_get_decoder_accepts_stim_dem_string():
788-
dem_text = ("error(0.1) D0 L0\n"
789-
"error(0.1) D1 L0\n"
790-
"error(0.05) D0 D1\n")
787+
def test_get_decoder_from_stim_dem_compatibility():
788+
dem_text = "error(0.1) D0 L0\n"
791789

792-
decoder = qec.get_decoder("single_error_lut", dem_text)
790+
decoder = qec.get_decoder_from_stim_dem("single_error_lut", dem_text)
793791
assert decoder is not None
794-
assert decoder.get_syndrome_size() == 2
795-
assert decoder.get_block_size() == 3
796-
797-
result = decoder.decode([1.0, 1.0])
798-
assert result.converged is True
799-
assert list(result.result) == [0.0, 0.0, 1.0]
792+
assert decoder.get_syndrome_size() == 1
793+
assert decoder.get_block_size() == 1
800794

801795

802796
def test_dem_from_stim_text_explicit_parse_then_get_decoder():
@@ -816,24 +810,31 @@ def test_dem_from_stim_text_explicit_parse_then_get_decoder():
816810
assert decoder.get_block_size() == 3
817811

818812

819-
def test_get_decoder_from_stim_dem_rejects_malformed_text():
813+
def test_get_decoder_rejects_malformed_stim_dem_text():
820814
with pytest.raises(RuntimeError):
821-
qec.get_decoder_from_stim_dem("single_error_lut", "not a valid DEM")
815+
qec.get_decoder("single_error_lut", "not a valid DEM")
822816

823817

824-
def test_get_decoder_from_stim_dem_rejects_unknown_decoder():
818+
def test_get_decoder_rejects_unknown_decoder_for_stim_dem_text():
825819
with pytest.raises(RuntimeError, match="__no_such_decoder__"):
826-
qec.get_decoder_from_stim_dem("__no_such_decoder__",
827-
"error(0.1) D0 L0\n")
820+
qec.get_decoder("__no_such_decoder__", "error(0.1) D0 L0\n")
828821

829822

830-
def test_get_decoder_from_stim_dem_user_O_wins_over_dem_derived():
823+
def test_get_decoder_user_O_wins_over_dem_derived():
831824
dem_text = ("error(0.1) D0 L0\n"
832825
"error(0.1) D1 L0\n"
833826
"error(0.05) D0 D1\n")
834827
bad_O = np.zeros((1, 4), dtype=np.uint8)
835828
with pytest.raises(RuntimeError):
836-
qec.get_decoder_from_stim_dem("pymatching", dem_text, O=bad_O)
829+
qec.get_decoder("pymatching", dem_text, O=bad_O)
830+
831+
832+
def test_get_decoder_stim_dem_without_observables_returns_errors():
833+
decoder = qec.get_decoder("pymatching", "error(0.1) D0\n")
834+
835+
result = decoder.decode([1.0])
836+
assert result.converged is True
837+
assert list(result.result) == [1.0]
837838

838839

839840
if __name__ == "__main__":

libs/qec/unittests/test_decoders.cpp

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -745,13 +745,13 @@ TEST(DecoderTest, GetBlockSizeAndSyndromeSize) {
745745
EXPECT_EQ(decoder2->get_syndrome_size(), new_syndrome_size);
746746
}
747747

748-
TEST(StimDemDecoderFactory, ConstructsLutDecoderFromStimDemText) {
748+
TEST(StimDemGetDecoder, ConstructsLutDecoderFromStimDemText) {
749749
const std::string dem_text = R"(error(0.1) D0 L0
750750
error(0.1) D1 L0
751751
error(0.05) D0 D1
752752
)";
753753

754-
auto d = cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text);
754+
auto d = cudaq::qec::get_decoder("single_error_lut", dem_text);
755755
ASSERT_NE(d, nullptr);
756756
EXPECT_EQ(d->get_syndrome_size(), 2u);
757757
EXPECT_EQ(d->get_block_size(), 3u);
@@ -778,26 +778,32 @@ error(0.05) D0 D1
778778
}
779779
}
780780

781-
TEST(StimDemDecoderFactory, UnifiedGetDecoderAcceptsStimDemString) {
781+
TEST(StimDemGetDecoder, StaticDecoderGetAcceptsStimDemString) {
782782
const std::string dem_text = R"(error(0.1) D0 L0
783783
error(0.1) D1 L0
784784
error(0.05) D0 D1
785785
)";
786786

787-
auto check = [&](std::unique_ptr<cudaq::qec::decoder> d) {
788-
ASSERT_NE(d, nullptr);
789-
EXPECT_EQ(d->get_syndrome_size(), 2u);
790-
EXPECT_EQ(d->get_block_size(), 3u);
791-
auto result = d->decode(std::vector<cudaq::qec::float_t>{1.0, 1.0});
792-
EXPECT_TRUE(result.converged);
793-
ASSERT_EQ(result.result.size(), 3u);
794-
EXPECT_FLOAT_EQ(result.result[2], 1.0);
795-
};
796-
check(cudaq::qec::get_decoder("single_error_lut", dem_text));
797-
check(cudaq::qec::decoder::get("single_error_lut", dem_text));
787+
auto d = cudaq::qec::decoder::get("single_error_lut", dem_text);
788+
ASSERT_NE(d, nullptr);
789+
EXPECT_EQ(d->get_syndrome_size(), 2u);
790+
EXPECT_EQ(d->get_block_size(), 3u);
791+
auto result = d->decode(std::vector<cudaq::qec::float_t>{1.0, 1.0});
792+
EXPECT_TRUE(result.converged);
793+
ASSERT_EQ(result.result.size(), 3u);
794+
EXPECT_FLOAT_EQ(result.result[2], 1.0);
795+
}
796+
797+
TEST(StimDemGetDecoder, DeprecatedStimDemHelperStillWorks) {
798+
const std::string dem_text = "error(0.1) D0 L0\n";
799+
800+
auto d = cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text);
801+
ASSERT_NE(d, nullptr);
802+
EXPECT_EQ(d->get_syndrome_size(), 1u);
803+
EXPECT_EQ(d->get_block_size(), 1u);
798804
}
799805

800-
TEST(StimDemDecoderFactory, UnifiedGetDecoderStillAcceptsParityCheckMatrix) {
806+
TEST(StimDemGetDecoder, StillAcceptsParityCheckMatrix) {
801807
cudaqx::tensor<uint8_t> H({2, 3});
802808
H.copy(std::vector<uint8_t>{1, 0, 1, 0, 1, 1}.data(), {2, 3});
803809
auto d = cudaq::qec::get_decoder("single_error_lut", H);
@@ -806,7 +812,7 @@ TEST(StimDemDecoderFactory, UnifiedGetDecoderStillAcceptsParityCheckMatrix) {
806812
EXPECT_EQ(d->get_block_size(), 3u);
807813
}
808814

809-
TEST(StimDemDecoderFactory, RepeatedDetectorOrObservableTargetsXorFold) {
815+
TEST(StimDemGetDecoder, RepeatedDetectorOrObservableTargetsXorFold) {
810816
const std::string dem_text = R"(error(0.1) D0 D0
811817
error(0.1) L0 L0
812818
)";
@@ -821,34 +827,40 @@ error(0.1) L0 L0
821827
<< "duplicate L0 in error 1 should XOR-cancel to 0";
822828
}
823829

824-
TEST(StimDemDecoderFactory, ThrowsOnProbabilityOutOfRange) {
830+
TEST(StimDemGetDecoder, DemWithoutObservablesDoesNotAddODefault) {
831+
auto dem = cudaq::qec::dem_from_stim_text("error(0.1) D0\n");
832+
auto defaults = cudaq::qec::dem_defaults_for_missing_keys(
833+
[](const std::string &) { return false; }, dem);
834+
835+
EXPECT_EQ(defaults.O, nullptr);
836+
ASSERT_NE(defaults.error_rate_vec, nullptr);
837+
EXPECT_EQ(defaults.error_rate_vec->size(), 1u);
838+
}
839+
840+
TEST(StimDemGetDecoder, ThrowsOnProbabilityOutOfRange) {
825841
const std::string dem_text = "error(1.5) D0\n";
826-
EXPECT_THROW(
827-
cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text),
828-
std::runtime_error);
842+
EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text),
843+
std::runtime_error);
829844
}
830845

831-
TEST(StimDemDecoderFactory, ThrowsOnMalformedStimDem) {
832-
EXPECT_THROW(cudaq::qec::get_decoder_from_stim_dem("single_error_lut",
833-
"not a valid DEM"),
846+
TEST(StimDemGetDecoder, ThrowsOnMalformedStimDem) {
847+
EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", "not a valid DEM"),
834848
std::runtime_error);
835849
}
836850

837-
TEST(StimDemDecoderFactory, ThrowsOnUnknownDecoderName) {
851+
TEST(StimDemGetDecoder, ThrowsOnUnknownDecoderName) {
838852
const std::string dem_text = "error(0.1) D0 L0\n";
839-
EXPECT_THROW(
840-
cudaq::qec::get_decoder_from_stim_dem("__no_such_decoder__", dem_text),
841-
std::runtime_error);
853+
EXPECT_THROW(cudaq::qec::get_decoder("__no_such_decoder__", dem_text),
854+
std::runtime_error);
842855
}
843856

844-
TEST(StimDemDecoderFactory, ThrowsOnEmptyErrorMechanisms) {
857+
TEST(StimDemGetDecoder, ThrowsOnEmptyErrorMechanisms) {
845858
const std::string dem_text = "detector(0, 0, 0)\n";
846-
EXPECT_THROW(
847-
cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text),
848-
std::runtime_error);
859+
EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text),
860+
std::runtime_error);
849861
}
850862

851-
TEST(StimDemDecoderFactory, StimDemTargetCategoriesAreExhaustive) {
863+
TEST(StimDemGetDecoder, StimDemTargetCategoriesAreExhaustive) {
852864
const std::vector<stim::DemTarget> samples = {
853865
stim::DemTarget::separator(),
854866
stim::DemTarget::relative_detector_id(0),
@@ -865,14 +877,13 @@ TEST(StimDemDecoderFactory, StimDemTargetCategoriesAreExhaustive) {
865877
}
866878
}
867879

868-
TEST(StimDemDecoderFactory, UserOptionsAreNotOverwritten) {
880+
TEST(StimDemGetDecoder, UserOptionsAreNotOverwritten) {
869881
const std::string dem_text = R"(error(0.1) D0 L0
870882
error(0.1) D1 L0
871883
error(0.05) D0 D1
872884
)";
873885
cudaqx::heterogeneous_map opts;
874886
opts.insert("error_rate_vec", std::vector<double>{0.5});
875-
EXPECT_THROW(
876-
cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text, opts),
877-
std::runtime_error);
887+
EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text, opts),
888+
std::runtime_error);
878889
}

0 commit comments

Comments
 (0)