Skip to content

Commit b0eb738

Browse files
Merge branch 'main' into pr-stim-dem-decoder-init
2 parents 1253540 + 80bc70c commit b0eb738

2 files changed

Lines changed: 60 additions & 12 deletions

File tree

libs/qec/python/bindings/py_decoder.cpp

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,54 @@ sparse_binary_matrix_from_py_dict(const nb::dict &d) {
9090
"Sparse H dict layout must be \"nested_csc\" or \"nested_csr\".");
9191
}
9292

93+
/// Convert a dense 2-D NumPy uint8 array to sparse_binary_matrix without
94+
/// any intermediate dense tensor allocation. Strides are read directly so
95+
/// both C-contiguous (row-major) and Fortran-contiguous (column-major) arrays
96+
/// are handled efficiently: the inner loop always traverses contiguous memory.
97+
static sparse_binary_matrix
98+
make_sparse_from_dense(const nb::ndarray<nb::numpy, uint8_t> &arr) {
99+
if (arr.ndim() != 2)
100+
throw std::invalid_argument("H must be a 2-D uint8 array");
101+
const std::size_t num_rows = arr.shape(0);
102+
const std::size_t num_cols = arr.shape(1);
103+
const std::ptrdiff_t rs = arr.stride(0); // bytes per row step
104+
const std::ptrdiff_t cs = arr.stride(1); // bytes per col step
105+
const uint8_t *base = static_cast<const uint8_t *>(arr.data());
106+
107+
using index_t = sparse_binary_matrix::index_type;
108+
std::vector<index_t> ptr, idx;
109+
110+
// C-order: inner loop over columns is sequential → build CSR.
111+
// F-order: inner loop over rows is sequential → build CSC.
112+
if (cs <= rs) {
113+
ptr.reserve(num_rows + 1);
114+
ptr.push_back(0);
115+
for (std::size_t i = 0; i < num_rows; ++i) {
116+
for (std::size_t j = 0; j < num_cols; ++j) {
117+
if (base[i * rs + j * cs])
118+
idx.push_back(static_cast<index_t>(j));
119+
}
120+
ptr.push_back(static_cast<index_t>(idx.size()));
121+
}
122+
return sparse_binary_matrix::from_csr(static_cast<index_t>(num_rows),
123+
static_cast<index_t>(num_cols),
124+
std::move(ptr), std::move(idx));
125+
} else {
126+
ptr.reserve(num_cols + 1);
127+
ptr.push_back(0);
128+
for (std::size_t j = 0; j < num_cols; ++j) {
129+
for (std::size_t i = 0; i < num_rows; ++i) {
130+
if (base[i * rs + j * cs])
131+
idx.push_back(static_cast<index_t>(i));
132+
}
133+
ptr.push_back(static_cast<index_t>(idx.size()));
134+
}
135+
return sparse_binary_matrix::from_csc(static_cast<index_t>(num_rows),
136+
static_cast<index_t>(num_cols),
137+
std::move(ptr), std::move(idx));
138+
}
139+
}
140+
93141
class PyDecoder : public decoder {
94142
public:
95143
NB_TRAMPOLINE(decoder, 1);
@@ -770,13 +818,6 @@ void bindDecoder(nb::module_ &mod) {
770818

771819
cudaq::qec::sparse_binary_matrix H_sparse;
772820

773-
auto make_sparse_from_dense =
774-
[](const nb::ndarray<nb::numpy, uint8_t> &arr) {
775-
auto tensor_H = cudaqx::pcmToTensor(arr);
776-
return cudaq::qec::sparse_binary_matrix(
777-
tensor_H, cudaq::qec::sparse_binary_matrix_layout::csc);
778-
};
779-
780821
if (nb::isinstance<nb::dict>(H))
781822
H_sparse = sparse_binary_matrix_from_py_dict(nb::cast<nb::dict>(H));
782823
else

libs/qec/python/tests/test_decoder.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,18 @@ def test_decoder_initialization():
2929
assert hasattr(decoder, 'decode')
3030

3131

32-
def test_decoder_initialization_with_error():
33-
# We do not support column-major order (Fortran order)
34-
H_bad = np.zeros((10, 20), dtype=np.uint8, order='F')
35-
with pytest.raises(RuntimeError) as e:
36-
decoder = qec.get_decoder('single_error_lut_example', H_bad)
32+
def test_decoder_initialization_with_fortran_order():
33+
# Fortran-order (column-major) arrays are now handled via stride-aware
34+
# scanning and should work correctly.
35+
H_f = np.eye(10, 20, dtype=np.uint8, order='F')
36+
H_c = np.ascontiguousarray(H_f)
37+
decoder_f = qec.get_decoder('single_error_lut_example', H_f)
38+
decoder_c = qec.get_decoder('single_error_lut_example', H_c)
39+
syndrome = np.zeros(H_f.shape[0], dtype=np.uint8)
40+
r_f = decoder_f.decode(syndrome)
41+
r_c = decoder_c.decode(syndrome)
42+
assert r_f.converged == r_c.converged
43+
assert list(r_f.result) == list(r_c.result)
3744

3845

3946
def test_decoder_api():

0 commit comments

Comments
 (0)