Commit d111f9a

docs(qec): document tensor network noise-learning decoders
Adds Sphinx API narrative and decoder docs for noise-learning integration, split from the integration PR for focused review. Signed-off-by: vedika-saravanan <vsaravanan@nvidia.com>
1 parent 08a5fe5 commit d111f9a

3 files changed

Lines changed: 228 additions & 1 deletion

docs/sphinx/api/qec/tensor_network_decoder_api.rst

Lines changed: 182 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -93,4 +93,185 @@
9393

9494
:param optimize: Optimization options or None
9595
:param batch_size: (int, optional) Batch size for optimization (default: -1, no batching)
96-
:returns: Optimizer info object
96+
:returns: Optimizer info object
97+
98+
.. class:: cudaq_qec.plugins.decoders.tensor_network_decoder.NMOptimizer
99+
100+
Differentiable noise-model optimizer built on top of :class:`TensorNetworkDecoder`.
101+
102+
Fits a factorised per-error noise model to a syndrome dataset by
103+
backpropagating through a torch-backed tensor-network contraction.
104+
The noise probabilities are maintained as ``torch`` tensors with
105+
``requires_grad=True`` so they can be updated with any ``torch.optim``
106+
optimizer.
107+
108+
Requires Python 3.11 or higher and the same optional dependencies as
109+
:class:`TensorNetworkDecoder` (``pip install cudaq-qec[tensor-network-decoder]``).
110+
PyTorch must also be installed.
111+
112+
.. note::
113+
Quick-start example (logit-space training; the loss has no ``log``
114+
guard, so direct probability training requires per-step clamping
115+
into ``[eps, 1 - eps]``)::
116+
117+
import numpy as np
118+
import torch
119+
from cudaq_qec.plugins.decoders.tensor_network_decoder import (
120+
NMOptimizer, make_compiled_step,
121+
)
122+
123+
H = np.array([[1, 1, 0], [0, 1, 1]], dtype=np.float64)
124+
logical = np.array([[1, 0, 1]], dtype=np.float64)
125+
priors = [0.1, 0.2, 0.3]
126+
127+
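# Illustrative placeholder data (not from a real experiment): the no-error
# shot plus one shot per single error in H, with the matching logical flips.
syndrome_data = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float64)
obs_flips = np.array([False, True, False, True])

opt = NMOptimizer(H, logical, priors, syndrome_data, obs_flips,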
128+
dtype="float64")
129+
logits = torch.logit(opt.noise_params[0].detach()).requires_grad_()
130+
adam = torch.optim.Adam([logits], lr=0.01)
131+
step = make_compiled_step(opt, logits, adam)
132+
for _ in range(100):
133+
step()
134+
135+
:param H: Parity check matrix (numpy.ndarray), shape (num_checks, num_errors)
136+
:param logical_obs: Logical observable matrix (numpy.ndarray), shape (1, num_errors)
137+
:param noise_model: Initial per-error probabilities, list of floats in (0, 1).
138+
Values outside ``[eps, 1 - eps]`` are clamped at
139+
construction with a ``UserWarning``; non-finite values
140+
raise ``ValueError``. ``eps`` is ``1e-12`` for
141+
``"float64"`` and ``1e-6`` for ``"float32"``.
142+
:param syndrome_data: Observed syndromes, numpy.ndarray of shape (num_shots, num_checks)
143+
:param observable_flips: Observed logical flips, bool array of length num_shots
144+
:param check_inds: (optional) List of check index names; defaults track the parent decoder.
145+
:param error_inds: (optional) List of error index names; defaults track the parent decoder.
146+
:param logical_inds: (optional) List of logical index names; defaults track the parent decoder.
147+
:param logical_tags: (optional) List of logical tags; defaults track the parent decoder.
148+
:param dtype: (str, optional) ``"float32"`` (default) or ``"float64"``;
149+
other values raise ``ValueError``.
150+
:param device: (str, optional) Torch device, e.g. ``"cpu"`` or ``"cuda"`` (default: ``"cuda"``)
151+
:param compile: (bool, optional, keyword-only) If ``True``, wrap the forward
152+
and loss in :func:`torch.compile`. Most useful with
153+
``execute="codegen"``. Defaults to ``False``.
154+
:param execute: (str, optional, keyword-only) Forward backend. ``"codegen"``
155+
(default) partially evaluates the contraction path into a flat
156+
Python function with named locals; ``"unrolled"`` keeps an
157+
interpretive einsum list; ``"opt_einsum"`` dispatches via
158+
:func:`opt_einsum.contract_expression`.
159+
:param compile_mode: (str, optional, keyword-only) Forwarded to
160+
:func:`torch.compile` (e.g. ``"reduce-overhead"``,
161+
``"default"``); ignored when ``compile=False``.
162+
:param dynamic_syndromes: (bool, optional, keyword-only) If ``True``
163+
(default), syndromes are runtime arguments to the
164+
compiled forward, so :meth:`update_dataset` reuses
165+
the codegen/``torch.compile`` artifact when shapes
166+
are unchanged. ``False`` bakes syndromes into the
167+
closure -- faster per call but every
168+
:meth:`update_dataset` rebuilds the graph. Only
169+
affects ``execute="codegen"``.
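
The clamping rule described for ``noise_model`` can be illustrated with a small standalone sketch. It is not the library's implementation: ``clamp_priors`` is a hypothetical helper written in plain Python, and only the ``eps`` values, the ``UserWarning``, and the ``ValueError`` cases are taken from the parameter descriptions above.

```python
import math
import warnings

# eps values as stated in the noise_model parameter description.
_EPS = {"float64": 1e-12, "float32": 1e-6}


def clamp_priors(priors, dtype="float32"):
    """Hypothetical helper: clamp probabilities into [eps, 1 - eps]."""
    if dtype not in _EPS:
        raise ValueError(f"unsupported dtype: {dtype!r}")
    eps = _EPS[dtype]
    clamped = []
    for p in priors:
        if not math.isfinite(p):
            raise ValueError(f"non-finite probability: {p!r}")
        q = min(max(p, eps), 1.0 - eps)
        if q != p:
            # Out-of-range values are adjusted with a warning, not rejected.
            warnings.warn(f"probability {p} clamped to {q}", UserWarning)
        clamped.append(q)
    return clamped
```

Under these assumptions, ``clamp_priors([0.0, 0.5, 1.0], "float64")`` leaves ``0.5`` untouched and moves the two endpoints just inside the open interval.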
170+
171+
**Attributes**
172+
173+
.. attribute:: noise_params
174+
175+
``list[torch.Tensor]`` — the learnable noise-probability tensors; pass
176+
directly to a ``torch.optim`` optimizer.
177+
178+
.. attribute:: torch_device
179+
180+
``torch.device`` derived from the ``device`` constructor argument.
181+
Read-only.
182+
183+
.. attribute:: observable_flips
184+
185+
Bool ``torch.Tensor`` of logical flip outcomes for the current
186+
syndrome batch. Assigning a new value also rebuilds the fused
187+
loss closure (the observable indices are baked into the codegen);
188+
prefer :meth:`update_dataset` when swapping syndromes and flips
189+
together.
190+
191+
**Methods**
192+
193+
.. method:: current_syndrome_args()
194+
195+
Return the syndrome argument expected by the callable from
196+
:meth:`loss_fn`: the live tuple when ``dynamic_syndromes=True``,
197+
or ``()`` for static codegen (syndromes are closure-baked).
198+
Re-fetch each step so an intervening :meth:`update_dataset` is
199+
reflected.
200+
201+
:returns: ``tuple[torch.Tensor, ...]``
202+
203+
.. method:: cross_entropy_loss()
204+
205+
Compute the cross-entropy loss between the predicted logical-flip
206+
probabilities and the observed ``observable_flips``.
207+
208+
:returns: Scalar ``torch.Tensor`` (differentiable).
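
As a rough illustration of the quantity involved, the sketch below computes a binary cross-entropy in plain Python from per-shot flip probabilities. The function name and the choice of a mean (rather than a sum) over shots are assumptions made for the example, not details confirmed by the library.

```python
import math


def binary_cross_entropy(p_flip, flips):
    """Mean negative log-likelihood of observed flips under P(flip)."""
    total = 0.0
    for p, observed in zip(p_flip, flips):
        # Use log(p) for a flipped shot and log(1 - p) otherwise.
        total -= math.log(p) if observed else math.log(1.0 - p)
    return total / len(flips)
```

A predicted probability of exactly 0 or 1 for the wrong outcome makes the logarithm undefined (``math.log(0.0)`` raises here), which is why probabilities trained directly need to stay inside ``[eps, 1 - eps]``.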
209+
210+
.. method:: decoder_prediction()
211+
212+
Run the forward pass and return per-shot probabilities.
213+
214+
:returns: ``torch.Tensor`` of shape ``(num_shots, 2)`` where column 1
215+
is ``P(logical flip | syndrome)``.
216+
217+
.. method:: logical_error_rate()
218+
219+
Fraction of shots where ``argmax`` of :meth:`decoder_prediction`
220+
disagrees with :attr:`observable_flips`. Not differentiable
221+
(runs under :func:`torch.no_grad`).
222+
223+
:returns: ``float`` in ``[0, 1]``.
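
The metric can be sketched without PyTorch as follows. ``logical_error_rate_sketch`` is a hypothetical stand-in that takes rows of ``[P(no flip), P(flip)]``, matching the column layout documented for :meth:`decoder_prediction`; breaking ties towards "no flip" is an assumption that mirrors how ``argmax`` returns the first maximal index.

```python
def logical_error_rate_sketch(predictions, flips):
    """Fraction of shots whose most likely outcome disagrees with the data."""
    wrong = 0
    for (p_no_flip, p_flip), observed in zip(predictions, flips):
        predicted = p_flip > p_no_flip  # argmax over the two columns
        wrong += predicted != bool(observed)
    return wrong / len(flips)


predictions = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
flips = [False, True, True, True]
rate = logical_error_rate_sketch(predictions, flips)  # one of four shots is wrong
```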
224+
225+
.. method:: loss_fn(from_logits=True)
226+
227+
Return a compiled callable ``fn(params, syndrome_tuple) -> loss``
228+
suitable for use with external optimizers or ``torch.compile``.
229+
230+
:param from_logits: If ``True`` (default), ``params`` are interpreted
231+
as logits and passed through ``sigmoid`` before
232+
contraction. If ``False``, ``params`` are
232+
interpreted as probabilities; the loss has no ``log`` guard, so keep them clamped to ``[eps, 1 - eps]``.
234+
:returns: Compiled loss function.
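
The logit parameterization behind ``from_logits=True`` can be illustrated with the standard sigmoid/logit pair. This is a plain-Python sketch of the mapping only; the helper names are illustrative and nothing here calls the library.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logit(p):
    """Inverse of sigmoid for p strictly between 0 and 1."""
    return math.log(p / (1.0 - p))


priors = [0.1, 0.2, 0.3]
logits = [logit(p) for p in priors]        # unconstrained trainable values
recovered = [sigmoid(z) for z in logits]   # back to probabilities
```

Because the sigmoid maps every real logit into the unit interval, an optimizer can update logits freely without an explicit clamp on each step.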
235+
236+
.. method:: optimize_path(optimize=None, batch_size=-1)
237+
238+
Cache a contraction path via quimb / opt_einsum and rebuild the
239+
compiled forward. Pass e.g. ``cotengra.HyperOptimizer()`` to run a
240+
more expensive path search; ``None`` falls back to ``"auto"``.
241+
242+
:param optimize: Optimization options (e.g. a ``cotengra.HyperOptimizer``)
243+
or ``None``.
244+
:param batch_size: Accepted for signature compatibility; ignored.
245+
:returns: Contraction info object.
246+
247+
.. method:: update_dataset(syndrome_data, observable_flips, enforce_shape=True)
248+
249+
Swap in a new syndrome batch without rebuilding the tensor network.
250+
If ``dynamic_syndromes=True`` and the batch size is unchanged, the
251+
compiled contraction path is reused; a shape change triggers a full
252+
rebuild.
253+
254+
:param syndrome_data: numpy.ndarray of shape (num_shots, num_checks)
255+
:param observable_flips: bool array of length num_shots
256+
:param enforce_shape: (bool, optional, default ``True``) Assert
257+
per-tensor shapes match the existing layout
258+
before patching in place. A batch-size change
259+
triggers a full rebuild regardless.
260+
261+
.. function:: cudaq_qec.plugins.decoders.tensor_network_decoder.make_compiled_step(optimizer, logits, torch_optimizer)
262+
263+
Build a no-arg callable that runs one step of ``torch_optimizer`` and returns the loss.
264+
265+
The returned ``step()`` callable zeros gradients, evaluates the
266+
optimizer's fused ``loss_fn(from_logits=True)`` (sigmoid + contraction +
267+
cross-entropy), backpropagates, and steps ``torch_optimizer``. Intended
268+
for training in logit space; pair with :class:`NMOptimizer` constructed
269+
with ``compile=True`` for a ``torch.compile``-d variant.
270+
271+
:param optimizer: An :class:`NMOptimizer` instance providing the fused
272+
inner loss.
273+
:param logits: Trainable 1-D ``torch.Tensor`` of length
274+
``len(optimizer.error_inds)`` with ``requires_grad=True``.
275+
:param torch_optimizer: A ``torch.optim`` instance owning ``logits``.
276+
:returns: A no-arg callable that performs one optimization step and
277+
returns the scalar loss as a ``torch.Tensor``.

docs/sphinx/components/qec/introduction.rst

Lines changed: 21 additions & 0 deletions
@@ -899,6 +899,27 @@ The decoder returns the probability that the logical observable has flipped for
899899
that this GPU will not be supported by the Tensor Network Decoder when
900900
CUDA-Q 0.5.0 is released.
901901

902+
Learning the Noise Model from Data
903+
""""""""""""""""""""""""""""""""""
904+
905+
When the true per-error noise rates are unknown (typical of real hardware),
906+
the Tensor Network Decoder ships with ``NMOptimizer``, a differentiable
907+
extension that **fits the noise model directly from observed syndromes and
908+
logical-flip outcomes**. Noise probabilities are held as PyTorch tensors
909+
with ``requires_grad=True``; backpropagating through the tensor-network
910+
contraction yields gradients that any ``torch.optim`` optimizer (Adam, SGD,
911+
etc.) can use to update them. Starting from a uniform initial prior, a few
912+
hundred Adam steps are usually enough to recover the per-error rates and beat a
913+
static-uniform baseline on a held-out batch.
914+
915+
Training is offline: it runs once on a representative syndrome
916+
dataset, and the learned probabilities can then be used as a standard
917+
static noise model for batch decoding. See
918+
:ref:`tensor_network_decoder_api_python` for the ``NMOptimizer`` API and
919+
the *Learning Noise Models with NMOptimizer* example in
920+
:doc:`../../examples_rst/qec/decoders` for a runnable end-to-end demo on a
921+
Stim repetition-code circuit.
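
The idea of fitting a rate from observed outcomes in logit space can be seen in a deliberately tiny toy problem. The sketch below fits a single flip probability to a list of observed flips using plain gradient descent (not Adam) and a hand-derived gradient; it does not use ``NMOptimizer`` or a tensor network, and ``fit_flip_rate`` with its defaults is purely illustrative.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fit_flip_rate(flips, steps=500, lr=0.5, init_p=0.5):
    """Fit one Bernoulli rate by gradient descent on its logit."""
    target = sum(flips) / len(flips)       # empirical flip frequency
    z = math.log(init_p / (1.0 - init_p))  # start from the initial prior
    for _ in range(steps):
        p = sigmoid(z)
        # For p = sigmoid(z), d(mean cross-entropy)/dz = p - mean(flips).
        z -= lr * (p - target)
    return sigmoid(z)


observed = [True] * 2 + [False] * 8
learned = fit_flip_rate(observed)  # converges towards the 20% flip frequency
```

In the real optimizer the gradient comes from backpropagating through the contraction rather than from a closed-form expression, but the update loop has the same shape.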
922+
902923

903924
Sliding Window Decoder
904925
^^^^^^^^^^^^^^^^^^^^^^

docs/sphinx/examples_rst/qec/decoders.rst

Lines changed: 25 additions & 0 deletions
@@ -136,6 +136,31 @@ See Also:
136136

137137
- ``cudaq_qec.plugins.decoders.tensor_network_decoder``
138138

139+
Learning Noise Models with NMOptimizer
140+
+++++++++++++++++++++++++++++++++++++++
141+
142+
:class:`~cudaq_qec.plugins.decoders.tensor_network_decoder.NMOptimizer` extends
143+
the Tensor Network Decoder with differentiable noise probabilities. Given a
144+
batch of observed syndromes and logical-flip outcomes, it fits per-error noise
145+
rates by backpropagating through the tensor-network contraction using PyTorch.
146+
147+
The following example builds a distance-3 repetition-code circuit with
148+
**asymmetric** noise (the data-qubit depolarization rate is 10x the measurement-flip
149+
probability), samples syndromes from Stim, and trains
150+
:class:`NMOptimizer` from a uniform initial prior with 300 Adam steps in
151+
logit space. It then compares the **logical error rate (LER)** of the
152+
learned noise model against a static uniform-prior baseline on a 20k-shot
153+
held-out batch — demonstrating that fitting per-error rates from data
154+
decodes meaningfully better than assuming uniform noise:
155+
156+
.. literalinclude:: ../../examples/qec/python/noise_learning.py
157+
:language: python
158+
:start-after: [Begin Documentation]
159+
160+
See Also:
161+
162+
- :ref:`tensor_network_decoder_api_python`
163+
139164
.. _deploying-ai-decoders:
140165

141166
Deploying AI Decoders with TensorRT
