Skip to content

Commit 008a3d2

Browse files
authored
Return Python decode result as NumPy arrays (#558)
## Description Updates Python decoder result access so decoded correction data is exposed as NumPy arrays instead of Python lists. This is a breaking Python API change: `DecoderResult.result` now returns a 1-D NumPy array. This affects: - `decoder.decode(...).result` - `decoder.decode_async(...).get().result` - tuple unpacking of `DecoderResult` - `decoder.decode_batch(...)[i].result` - iteration over `BatchDecoderResult` Existing code that reads, indexes, iterates, or numerically consumes `result` should generally continue to work. Code that requires `result` to be an actual Python `list`, mutates it with list APIs, compares it with list equality, or serializes it directly should migrate to NumPy semantics or call `.tolist()`. Also updates `decode_batch(...)` to return `BatchDecoderResult` instead of `list[DecoderResult]`. `BatchDecoderResult` exposes vectorized batch fields: - `result`: 2D NumPy array - `converged`: 1D NumPy bool array - `opt_results`: list-like per-shot optional results Compatibility for existing batch consumers is preserved through indexing, slicing, `len(...)`, and iteration: - `batch[i]` materializes a per-shot `DecoderResult` - `batch[a:b]` returns another `BatchDecoderResult` - `for r in batch` continues to yield per-shot `DecoderResult` objects This avoids forcing user code onto the old per-shot Python list extraction path when the natural batch output is already array-shaped. Python decoder plugin authors with custom `decode_batch(...)` overrides must now return `BatchDecoderResult`, not `list[DecoderResult]`. User-facing Sphinx documentation will be updated in a separate docs PR, per project guidance that feature docs should not land before release. ## Runtime / performance impact This change is intended to reduce Python-side result extraction overhead for batch decoding by allowing users to read batch results directly as NumPy arrays instead of iterating through a list of `DecoderResult` objects. Minibench run on commit `246114b0a33e9a58ff3a96e9abb77ac72e945fd1` using the C++ `single_error_lut` decoder. The benchmark uses a CUDA-Q/Stim Steane memory syndrome workload, then splits the fixed shot set into chunks of `batch_size`. Each timed repeat loops over all chunks and measures wall-clock time from calling `decode_batch(...)` through reading the decoded result. Two current-branch access patterns are compared: - `numpy_view`: `decoded = decoder.decode_batch(chunk)`, then read `decoded.result` and `decoded.converged` - `compat_iteration`: `decoded = decoder.decode_batch(chunk)`, then iterate `for r in decoded` and read `r.result` and `r.converged` (this is the backward compatible list access pattern) Each row reports the median of 5 measured repeats. Speedup is `compat_iteration median / numpy_view median`. | shots | batch size | decode_batch calls | result width | NumPy view median ms | compat iteration median ms | speedup | | ---: | ---: | ---: | ---: | ---: | ---: | ---: | | 1000 | 1 | 1000 | 697 | 3.892 | 7.234 | 1.86x | | 1000 | 32 | 32 | 697 | 1.990 | 4.111 | 2.07x | | 1000 | 256 | 4 | 697 | 2.399 | 3.936 | 1.64x | | 3000 | 1 | 3000 | 697 | 11.748 | 21.462 | 1.83x | | 3000 | 32 | 94 | 697 | 6.040 | 12.162 | 2.01x | | 3000 | 256 | 12 | 697 | 5.923 | 11.818 | 1.99x | In this minibench, direct NumPy batch access is about `1.6x-2.1x` faster than compatibility iteration for the measured `single_error_lut` cases. --------- Signed-off-by: Melody Ren <melodyr@nvidia.com>
1 parent d199363 commit 008a3d2

8 files changed

Lines changed: 601 additions & 80 deletions

File tree

libs/qec/python/bindings/py_decoder.cpp

Lines changed: 344 additions & 6 deletions
Large diffs are not rendered by default.

libs/qec/python/cudaq_qec/__init__.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def _ensure_cuda_runtime_loaded():
1818
_ensure_cuda_runtime_loaded()
1919
del _ensure_cuda_runtime_loaded
2020

21+
import functools
22+
2123
from .patch import patch
2224
try:
2325
from ._pycudaqx_qec_the_suffix_matters_cudaq_qec import *
@@ -39,15 +41,51 @@ def _ensure_cuda_runtime_loaded():
3941
__version__ = qecrt.__version__
4042
code = qecrt.code
4143
Code = qecrt.Code
42-
decoder = qecrt.decoder
44+
_native_decoder = qecrt.decoder
4345
Decoder = qecrt.Decoder
46+
47+
48+
def decoder(name):
49+
"""Register a Python class as a decoder plugin under `name`.
50+
51+
Wraps the native registration decorator so that any user-defined
52+
`decode_batch` override is checked at runtime to return a
53+
BatchDecoderResult. Returning a list[DecoderResult] (the pre-batch API)
54+
is no longer supported.
55+
"""
56+
native = _native_decoder(name)
57+
58+
def wrap(cls):
59+
if "decode_batch" in cls.__dict__:
60+
original = cls.decode_batch
61+
cls_name = cls.__name__
62+
63+
@functools.wraps(original)
64+
def checked_decode_batch(self, *args, **kwargs):
65+
result = original(self, *args, **kwargs)
66+
if not isinstance(result, qecrt.BatchDecoderResult):
67+
raise TypeError(
68+
f"{cls_name}.decode_batch must return a "
69+
f"BatchDecoderResult; got "
70+
f"{type(result).__name__}. See BatchDecoderResult "
71+
f"in the cudaq_qec docs for the supported "
72+
f"construction surface.")
73+
return result
74+
75+
cls.decode_batch = checked_decode_batch
76+
return native(cls)
77+
78+
return wrap
79+
80+
4481
TwoQubitDepolarization = qecrt.TwoQubitDepolarization
4582
TwoQubitBitFlip = qecrt.TwoQubitBitFlip
4683
operation = qecrt.operation
4784
get_code = qecrt.get_code
4885
get_available_codes = qecrt.get_available_codes
4986
get_decoder = qecrt.get_decoder
5087
DecoderResult = qecrt.DecoderResult
88+
BatchDecoderResult = qecrt.BatchDecoderResult
5189
DetectorErrorModel = qecrt.DetectorErrorModel
5290
generate_random_bit_flips = qecrt.generate_random_bit_flips
5391
sample_memory_circuit = qecrt.sample_memory_circuit

libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_decoder.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from typing import Any
1212
import cudaq_qec as qec
13+
import numpy as np
1314

1415
import numpy.typing as npt
1516
from quimb.tensor import TensorNetwork
@@ -416,15 +417,17 @@ def decode(
416417
def decode_batch(
417418
self,
418419
syndrome_batch: npt.NDArray[Any],
419-
) -> list["qec.DecoderResult"]:
420+
) -> "qec.BatchDecoderResult":
420421
"""Decode a batch of detection events.
421422
422423
Args:
423424
syndrome_batch (np.ndarray): A numpy array of shape (batch_size, syndrome_length) where each row is a detection event.
424425
425426
Returns:
426-
list[qec.DecoderResult]: list of results for each detection event in the batch.
427-
The probabilities that the logical observable flipped for each syndrome.
427+
qec.BatchDecoderResult: batched results for each detection event in
428+
the batch. The `result` field has shape (batch_size, 1) and
429+
contains the probabilities that the logical observable flipped
430+
for each syndrome.
428431
"""
429432

430433
assert hasattr(self, "noise_model")
@@ -459,16 +462,20 @@ def decode_batch(
459462
device_id=self.contractor_config.device_id,
460463
)
461464

462-
res = []
465+
probabilities = []
463466
for r in range(syndrome_batch.shape[0]):
464-
res.append(qec.DecoderResult())
465-
res[r].converged = True
466-
res[r].result = [
467+
probabilities.append(
467468
float(contraction_value[r, 1] /
468-
(contraction_value[r, 1] + contraction_value[r, 0]))
469-
]
470-
471-
return res
469+
(contraction_value[r, 1] + contraction_value[r, 0])))
470+
471+
# Python `decode_batch` override: construct a BatchDecoderResult
472+
# directly, bypassing the native decoder aggregation path. This is
473+
# the only sanctioned caller of the BatchDecoderResult constructor;
474+
# see its docstring for the supported construction surface.
475+
return qec.BatchDecoderResult(
476+
np.asarray(probabilities, dtype=np.float64).reshape((-1, 1)),
477+
np.ones(syndrome_batch.shape[0], dtype=bool),
478+
)
472479

473480
def optimize_path(
474481
self,

libs/qec/python/tests/test_decoder.py

Lines changed: 160 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,41 @@ def test_decoder_api():
4141
decoder = qec.get_decoder('example_byod', H)
4242
result = decoder.decode_batch(
4343
[create_test_syndrome(), create_test_syndrome()])
44+
assert isinstance(result, qec.BatchDecoderResult)
4445
assert len(result) == 2
45-
for r in result:
46-
assert hasattr(r, 'converged')
47-
assert hasattr(r, 'result')
48-
assert isinstance(r.converged, bool)
49-
assert isinstance(r.result, list)
50-
assert len(r.result) == 10
46+
assert hasattr(result, 'converged')
47+
assert hasattr(result, 'result')
48+
assert hasattr(result, 'opt_results')
49+
assert isinstance(result.converged, np.ndarray)
50+
assert result.converged.shape == (2,)
51+
assert isinstance(result.result, np.ndarray)
52+
assert result.result.shape == (2, 10)
53+
assert len(result.opt_results) == 2
54+
assert all(opt is None for opt in result.opt_results)
55+
first = result[0]
56+
assert isinstance(first, qec.DecoderResult)
57+
assert first.converged == result.converged[0]
58+
np.testing.assert_array_equal(first.result, result.result[0])
59+
assert first.opt_results is None
60+
last = result[-1]
61+
np.testing.assert_array_equal(last.result, result.result[-1])
62+
sliced = result[:1]
63+
assert isinstance(sliced, qec.BatchDecoderResult)
64+
assert sliced.result.shape == (1, 10)
65+
assert sliced.converged.shape == (1,)
66+
assert len(sliced.opt_results) == 1
67+
iterated = list(result)
68+
assert len(iterated) == 2
69+
assert all(isinstance(r, qec.DecoderResult) for r in iterated)
70+
np.testing.assert_array_equal(iterated[1].result, result.result[1])
71+
72+
# Empty batch: shape (0, 0); per-shot width is undefined without input.
73+
empty_result = decoder.decode_batch([])
74+
assert isinstance(empty_result, qec.BatchDecoderResult)
75+
assert empty_result.result.shape == (0, 0)
76+
assert empty_result.converged.shape == (0,)
77+
assert len(empty_result) == 0
78+
assert len(empty_result.opt_results) == 0
5179

5280
# Test decode_async
5381
decoder = qec.get_decoder('example_byod', H)
@@ -59,8 +87,8 @@ def test_decoder_api():
5987
assert hasattr(result, 'converged')
6088
assert hasattr(result, 'result')
6189
assert isinstance(result.converged, bool)
62-
assert isinstance(result.result, list)
63-
assert len(result.result) == 10
90+
assert isinstance(result.result, np.ndarray)
91+
assert result.result.shape == (10,)
6492

6593

6694
def test_decoder_result_structure():
@@ -72,8 +100,8 @@ def test_decoder_result_structure():
72100
assert hasattr(result, 'result')
73101
assert hasattr(result, 'opt_results')
74102
assert isinstance(result.converged, bool)
75-
assert isinstance(result.result, list)
76-
assert len(result.result) == 10
103+
assert isinstance(result.result, np.ndarray)
104+
assert result.result.shape == (10,)
77105

78106
# Test opt_results functionality
79107
assert result.opt_results is None # Default should be None
@@ -85,6 +113,107 @@ def test_decoder_result_structure():
85113
assert result.opt_results is None
86114

87115

116+
def test_batch_decoder_result_constructor():
117+
result = np.zeros((2, 3), dtype=np.float64)
118+
converged = np.array([True, False], dtype=np.bool_)
119+
batch_result = qec.BatchDecoderResult(result, converged)
120+
121+
assert isinstance(batch_result.result, np.ndarray)
122+
assert batch_result.result.shape == (2, 3)
123+
assert batch_result.converged.tolist() == [True, False]
124+
assert batch_result.opt_results == [None, None]
125+
assert batch_result[np.int64(0)].converged is True
126+
assert batch_result[::2].result.shape == (1, 3)
127+
assert np.shares_memory(batch_result[::2].result, batch_result.result)
128+
129+
empty_result = qec.BatchDecoderResult(np.empty((0, 3), dtype=np.float64),
130+
np.array([], dtype=np.bool_))
131+
assert len(empty_result) == 0
132+
assert empty_result.result.shape == (0, 3)
133+
assert empty_result.converged.shape == (0,)
134+
assert list(empty_result) == []
135+
with pytest.raises(IndexError):
136+
empty_result[0]
137+
138+
# Cross-array invariants we still enforce.
139+
with pytest.raises(RuntimeError, match="row count must match"):
140+
qec.BatchDecoderResult(result,
141+
np.array([True, False, True], dtype=np.bool_))
142+
143+
with pytest.raises(RuntimeError, match="opt_results length must match"):
144+
qec.BatchDecoderResult(result, converged, [None])
145+
146+
# Rank is enforced by nanobind, surfaced as TypeError.
147+
with pytest.raises(TypeError):
148+
qec.BatchDecoderResult(np.zeros(3, dtype=np.float64), converged)
149+
150+
# dtype and contiguity are coerced silently, not rejected.
151+
int_result = qec.BatchDecoderResult(np.zeros((2, 3), dtype=np.int32),
152+
converged)
153+
assert int_result.result.dtype == batch_result.result.dtype
154+
assert int_result.result.flags.c_contiguous
155+
156+
f_order = np.asfortranarray(np.zeros((2, 3), dtype=np.float64))
157+
assert not f_order.flags.c_contiguous
158+
f_result = qec.BatchDecoderResult(f_order, converged)
159+
assert f_result.result.flags.c_contiguous
160+
161+
162+
def test_python_decoder_batch_preserves_opt_results():
163+
164+
@qec.decoder("python_opt_results_byod")
165+
class PythonOptResultsDecoder:
166+
167+
def __init__(self, H, **kwargs):
168+
qec.Decoder.__init__(self, H)
169+
self.H = H
170+
171+
def decode(self, syndrome):
172+
res = qec.DecoderResult()
173+
res.converged = True
174+
res.result = np.arange(self.H.shape[1], dtype=np.float64)
175+
res.opt_results = {
176+
"syndrome_weight": int(np.count_nonzero(syndrome)),
177+
"tag": "python"
178+
}
179+
return res
180+
181+
decoder = qec.get_decoder("python_opt_results_byod", H)
182+
batch_result = decoder.decode_batch(
183+
[np.zeros(H.shape[0]), np.ones(H.shape[0])])
184+
185+
assert isinstance(batch_result, qec.BatchDecoderResult)
186+
assert batch_result.result.shape == (2, H.shape[1])
187+
assert batch_result.converged.tolist() == [True, True]
188+
assert batch_result.opt_results[0]["syndrome_weight"] == 0
189+
assert batch_result.opt_results[1]["syndrome_weight"] == H.shape[0]
190+
assert batch_result[1].opt_results["tag"] == "python"
191+
192+
193+
def test_python_decoder_batch_override_must_return_batch_decoder_result():
194+
195+
@qec.decoder("python_bad_batch_byod")
196+
class PythonBadBatchDecoder:
197+
198+
def __init__(self, H, **kwargs):
199+
qec.Decoder.__init__(self, H)
200+
self.H = H
201+
202+
def decode(self, syndrome):
203+
res = qec.DecoderResult()
204+
res.converged = True
205+
res.result = np.zeros(self.H.shape[1], dtype=np.float64)
206+
return res
207+
208+
def decode_batch(self, syndromes):
209+
# Pre-batch return shape; the decorator should reject this.
210+
return [self.decode(s) for s in syndromes]
211+
212+
decoder = qec.get_decoder("python_bad_batch_byod", H)
213+
with pytest.raises(TypeError, match="must return a BatchDecoderResult"):
214+
decoder.decode_batch([np.zeros(H.shape[0]), np.ones(H.shape[0])])
215+
216+
88217
def test_decoder_plugin_initialization():
89218
decoder = qec.get_decoder('single_error_lut_example', H)
90219
assert decoder is not None
@@ -133,16 +262,16 @@ def test_decoder_plugin_result_structure():
133262
assert hasattr(result, 'converged')
134263
assert hasattr(result, 'result')
135264
assert isinstance(result.converged, bool)
136-
assert isinstance(result.result, list)
265+
assert isinstance(result.result, np.ndarray)
137266

138267

139268
def test_decoder_result_values():
140269
decoder = qec.get_decoder('example_byod', H)
141270
result = decoder.decode(create_test_syndrome())
142271

143272
assert result.converged is True
144-
assert all(isinstance(x, float) for x in result.result)
145-
assert all(0 <= x <= 1 for x in result.result)
273+
assert isinstance(result.result, np.ndarray)
274+
assert np.all((0 <= result.result) & (result.result <= 1))
146275

147276

148277
@pytest.mark.parametrize("matrix_shape,syndrome_size", [((5, 10), 5),
@@ -158,8 +287,8 @@ def test_decoder_different_matrix_sizes(matrix_shape, syndrome_size):
158287

159288
assert len(result) == syndrome_size
160289
assert convergence is True
161-
assert all(isinstance(x, float) for x in result)
162-
assert all(0 <= x <= 1 for x in result)
290+
assert isinstance(result, np.ndarray)
291+
assert np.all((0 <= result) & (result <= 1))
163292

164293

165294
# FIXME add this back
@@ -187,7 +316,7 @@ def test_decoder_reproducibility():
187316
np.random.seed(42)
188317
convergence2, result2, opt2 = decoder.decode(create_test_syndrome())
189318

190-
assert result1 == result2
319+
np.testing.assert_array_equal(result1, result2)
191320
assert convergence1 == convergence2
192321

193322

@@ -338,6 +467,19 @@ def test_single_error_lut_opt_results():
338467
assert "syndrome_weight" in result.opt_results
339468
assert "decoding_time" not in result.opt_results # Was set to False
340469

470+
batch_result = decoder.decode_batch(
471+
[create_test_syndrome(), create_test_syndrome()])
472+
assert isinstance(batch_result.result, np.ndarray)
473+
assert batch_result.result.shape == (2, H.shape[1])
474+
assert isinstance(batch_result.converged, np.ndarray)
475+
assert batch_result.converged.shape == (2,)
476+
assert len(batch_result.opt_results) == 2
477+
for opt_results in batch_result.opt_results:
478+
assert opt_results is not None
479+
assert "error_probability" in opt_results
480+
assert "syndrome_weight" in opt_results
481+
assert "decoding_time" not in opt_results
482+
341483

342484
def test_decoder_pymatching_results():
343485
pcm = qec.generate_random_pcm(n_rounds=2,
@@ -353,8 +495,8 @@ def test_decoder_pymatching_results():
353495
decoder = qec.get_decoder('pymatching', pcm)
354496
result = decoder.decode(syndrome)
355497
assert result.converged is True
356-
assert all(isinstance(x, float) for x in result.result)
357-
assert all(0 <= x <= 1 for x in result.result)
498+
assert isinstance(result.result, np.ndarray)
499+
assert np.all((0 <= result.result) & (result.result <= 1))
358500
actual_errors = np.zeros(pcm.shape[1], dtype=np.uint8)
359501
actual_errors[columns] = 1
360502
assert np.array_equal(result.result, actual_errors)

0 commit comments

Comments
 (0)