93 | 93 |
94 | 94 | :param optimize: Optimization options or None |
95 | 95 | :param batch_size: (int, optional) Batch size for optimization (default: -1, no batching) |
96 | | - :returns: Optimizer info object |
| 96 | + :returns: Optimizer info object |
| 97 | + |
| 98 | +.. class:: cudaq_qec.plugins.decoders.tensor_network_decoder.NMOptimizer |
| 99 | + |
| 100 | + Differentiable noise-model optimizer built on top of :class:`TensorNetworkDecoder`. |
| 101 | + |
| 102 | + Fits a factorised per-error noise model to a syndrome dataset by |
| 103 | + backpropagating through a torch-backed tensor-network contraction. |
| 104 | + The noise probabilities are maintained as ``torch`` tensors with |
| 105 | + ``requires_grad=True`` so they can be updated with any ``torch.optim`` |
| 106 | + optimizer. |
| 107 | + |
| 108 | + Requires Python 3.11 or higher and the same optional dependencies as |
| 109 | + :class:`TensorNetworkDecoder` (``pip install cudaq-qec[tensor-network-decoder]``). |
| 110 | + PyTorch must also be installed. |
| 111 | + |
| 112 | + .. note:: |
| 113 | + Quick-start example (logit-space training; the loss has no ``log`` |
| 114 | + guard, so direct probability training requires per-step clamping |
| 115 | + into ``[eps, 1 - eps]``):: |
| 116 | + |
| 117 | + import numpy as np |
| 118 | + import torch |
| 119 | + from cudaq_qec.plugins.decoders.tensor_network_decoder import ( |
| 120 | + NMOptimizer, make_compiled_step, |
| 121 | + ) |
| 122 | + |
| 123 | + H = np.array([[1, 1, 0], [0, 1, 1]], dtype=np.float64) |
| 124 | + logical = np.array([[1, 0, 1]], dtype=np.float64) |
| 125 | + priors = [0.1, 0.2, 0.3] |
| 126 | + # syndrome_data: shape (num_shots, 2); obs_flips: bool, length num_shots (defined elsewhere)
| 127 | + opt = NMOptimizer(H, logical, priors, syndrome_data, obs_flips, |
| 128 | + dtype="float64") |
| 129 | + logits = torch.logit(opt.noise_params[0].detach()).requires_grad_() |
| 130 | + adam = torch.optim.Adam([logits], lr=0.01) |
| 131 | + step = make_compiled_step(opt, logits, adam) |
| 132 | + for _ in range(100): |
| 133 | + step() |
| 134 | + |
| 135 | + :param H: Parity check matrix (numpy.ndarray), shape (num_checks, num_errors) |
| 136 | + :param logical_obs: Logical observable matrix (numpy.ndarray), shape (1, num_errors) |
| 137 | + :param noise_model: Initial per-error probabilities, list of floats in (0, 1). |
| 138 | + Values outside ``[eps, 1 - eps]`` are clamped at |
| 139 | + construction with a ``UserWarning``; non-finite values |
| 140 | + raise ``ValueError``. ``eps`` is ``1e-12`` for |
| 141 | + ``"float64"`` and ``1e-6`` for ``"float32"``. |
| 142 | + :param syndrome_data: Observed syndromes, numpy.ndarray of shape (num_shots, num_checks) |
| 143 | + :param observable_flips: Observed logical flips, bool array of length num_shots |
| 144 | + :param check_inds: (optional) List of check index names; defaults track the parent decoder. |
| 145 | + :param error_inds: (optional) List of error index names; defaults track the parent decoder. |
| 146 | + :param logical_inds: (optional) List of logical index names; defaults track the parent decoder. |
| 147 | + :param logical_tags: (optional) List of logical tags; defaults track the parent decoder. |
| 148 | + :param dtype: (str, optional) ``"float32"`` (default) or ``"float64"``; |
| 149 | + other values raise ``ValueError``. |
| 150 | + :param device: (str, optional) Torch device, e.g. ``"cpu"`` or ``"cuda"`` (default: ``"cuda"``) |
| 151 | + :param compile: (bool, optional, keyword-only) If ``True``, wrap the forward |
| 152 | + and loss in :func:`torch.compile`. Most useful with |
| 153 | + ``execute="codegen"``. Defaults to ``False``. |
| 154 | + :param execute: (str, optional, keyword-only) Forward backend. ``"codegen"`` |
| 155 | + (default) partial-evaluates the contraction path into a flat |
| 156 | + Python function with named locals; ``"unrolled"`` keeps an |
| 157 | + interpretive einsum list; ``"opt_einsum"`` dispatches via |
| 158 | + :func:`opt_einsum.contract_expression`. |
| 159 | + :param compile_mode: (str, optional, keyword-only) Forwarded to |
| 160 | + :func:`torch.compile` (e.g. ``"reduce-overhead"``, |
| 161 | + ``"default"``); ignored when ``compile=False``. |
| 162 | + :param dynamic_syndromes: (bool, optional, keyword-only) If ``True`` |
| 163 | + (default), syndromes are runtime arguments to the |
| 164 | + compiled forward, so :meth:`update_dataset` reuses |
| 165 | + the codegen/``torch.compile`` artifact when shapes |
| 166 | + are unchanged. ``False`` bakes syndromes into the |
| 167 | + closure -- faster per call but every |
| 168 | + :meth:`update_dataset` rebuilds the graph. Only |
| 169 | + affects ``execute="codegen"``. |
| 170 | + |
| 171 | + **Attributes** |
| 172 | + |
| 173 | + .. attribute:: noise_params |
| 174 | + |
| 175 | + ``list[torch.Tensor]`` — the learnable noise-probability tensors; pass |
| 176 | + directly to a ``torch.optim`` optimizer. |
| 177 | + |
| 178 | + .. attribute:: torch_device |
| 179 | + |
| 180 | + ``torch.device`` derived from the ``device`` constructor argument. |
| 181 | + Read-only. |
| 182 | + |
| 183 | + .. attribute:: observable_flips |
| 184 | + |
| 185 | + Bool ``torch.Tensor`` of logical flip outcomes for the current |
| 186 | + syndrome batch. Assigning a new value also rebuilds the fused |
| 187 | + loss closure (the observable indices are baked into the codegen); |
| 188 | + prefer :meth:`update_dataset` when swapping syndromes and flips |
| 189 | + together. |
| 190 | + |
| 191 | + **Methods** |
| 192 | + |
| 193 | + .. method:: current_syndrome_args() |
| 194 | + |
| 195 | + Return the syndrome argument expected by the callable from |
| 196 | + :meth:`loss_fn`: the live tuple when ``dynamic_syndromes=True``, |
| 197 | + or ``()`` for static codegen (syndromes are closure-baked). |
| 198 | + Re-fetch each step so an intervening :meth:`update_dataset` is |
| 199 | + reflected. |
| 200 | + |
| 201 | + :returns: ``tuple[torch.Tensor, ...]`` |
| 202 | + |
| 203 | + .. method:: cross_entropy_loss() |
| 204 | + |
| 205 | + Compute the cross-entropy loss between the predicted logical-flip |
| 206 | + probabilities and the observed ``observable_flips``. |
| 207 | + |
| 208 | + :returns: Scalar ``torch.Tensor`` (differentiable). |
| 209 | + |
| 210 | + .. method:: decoder_prediction() |
| 211 | + |
| 212 | + Run the forward pass and return per-shot probabilities. |
| 213 | + |
| 214 | + :returns: ``torch.Tensor`` of shape ``(num_shots, 2)`` where column 1 |
| 215 | + is ``P(logical flip | syndrome)``. |
| 216 | + |
| 217 | + .. method:: logical_error_rate() |
| 218 | + |
| 219 | + Fraction of shots where ``argmax`` of :meth:`decoder_prediction` |
| 220 | + disagrees with :attr:`observable_flips`. Not differentiable |
| 221 | + (runs under :func:`torch.no_grad`). |
| 222 | + |
| 223 | + :returns: ``float`` in ``[0, 1]``. |
| 224 | + |
| 225 | + .. method:: loss_fn(from_logits=True) |
| 226 | + |
| 227 | + Return a compiled callable ``fn(params, syndrome_tuple) -> loss`` |
| 228 | + suitable for use with external optimizers or ``torch.compile``. |
| 229 | + |
| 230 | + :param from_logits: If ``True`` (default), ``params`` are interpreted |
| 231 | + as logits and passed through ``sigmoid`` before |
| 232 | + contraction. If ``False``, ``params`` are |
| 233 | + interpreted as probabilities already in ``[0, 1]``. |
| 234 | + :returns: Compiled loss function. |
| 235 | + |
| 236 | + .. method:: optimize_path(optimize=None, batch_size=-1) |
| 237 | + |
| 238 | + Cache a contraction path via quimb / opt_einsum and rebuild the |
| 239 | + compiled forward. Pass e.g. ``cotengra.HyperOptimizer()`` to run a |
| 240 | + more expensive path search; ``None`` falls back to ``"auto"``. |
| 241 | + |
| 242 | + :param optimize: Optimization options (e.g. a ``cotengra.HyperOptimizer``) |
| 243 | + or ``None``. |
| 244 | + :param batch_size: Accepted for signature compatibility; ignored. |
| 245 | + :returns: Contraction info object. |
| 246 | + |
| 247 | + .. method:: update_dataset(syndrome_data, observable_flips, enforce_shape=True) |
| 248 | + |
| 249 | + Swap in a new syndrome batch without rebuilding the tensor network. |
| 250 | + If ``dynamic_syndromes=True`` and the batch size is unchanged, the |
| 251 | + compiled contraction path is reused; a shape change triggers a full |
| 252 | + rebuild. |
| 253 | + |
| 254 | + :param syndrome_data: numpy.ndarray of shape (num_shots, num_checks) |
| 255 | + :param observable_flips: bool array of length num_shots |
| 256 | + :param enforce_shape: (bool, optional, default ``True``) Assert |
| 257 | + per-tensor shapes match the existing layout |
| 258 | + before patching in place. A batch-size change |
| 259 | + triggers a full rebuild regardless. |
| 260 | + |
| 261 | +.. function:: cudaq_qec.plugins.decoders.tensor_network_decoder.make_compiled_step(optimizer, logits, torch_optimizer) |
| 262 | + |
| 263 | + Build a no-arg callable that runs one optimization step and returns the loss.
| 264 | + |
| 265 | + The returned ``step()`` callable zeros gradients, evaluates the |
| 266 | + optimizer's fused ``loss_fn(from_logits=True)`` (sigmoid + contraction + |
| 267 | + cross-entropy), backpropagates, and steps ``torch_optimizer``. Intended |
| 268 | + for training in logit space; pair with :class:`NMOptimizer` constructed |
| 269 | + with ``compile=True`` for a ``torch.compile``-d variant. |
| 270 | + |
| 271 | + :param optimizer: An :class:`NMOptimizer` instance providing the fused |
| 272 | + inner loss. |
| 273 | + :param logits: Trainable 1-D ``torch.Tensor`` of length |
| 274 | + ``len(optimizer.error_inds)`` with ``requires_grad=True``. |
| 275 | + :param torch_optimizer: A ``torch.optim`` instance owning ``logits``. |
| 276 | + :returns: A no-arg callable that performs one optimization step and |
| 277 | + returns the scalar loss as a ``torch.Tensor``. |
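
The logit-space parametrization and the ``[eps, 1 - eps]`` clamp described in the quick-start note and the ``noise_model`` parameter can be illustrated without the decoder itself. What follows is a minimal standard-library sketch, not the ``NMOptimizer`` implementation. The helper names (``clamp_prob``, ``logit``, ``sigmoid``, ``binary_cross_entropy``) are hypothetical, and the loss is the textbook binary cross-entropy, assumed here only to show why an unguarded ``log`` needs clamped inputs.

```python
import math

# eps values quoted in the noise_model parameter description
EPS = {"float64": 1e-12, "float32": 1e-6}


def clamp_prob(p, dtype="float64"):
    """Clamp a probability into [eps, 1 - eps]; reject non-finite input."""
    if not math.isfinite(p):
        raise ValueError(f"non-finite probability: {p!r}")
    eps = EPS[dtype]
    return min(max(p, eps), 1.0 - eps)


def logit(p):
    """Map a probability in (0, 1) to an unconstrained real number."""
    return math.log(p / (1.0 - p))


def sigmoid(x):
    """Inverse of logit; the result always lies strictly inside (0, 1)
    for moderate x, so a gradient step on x cannot leave the valid range."""
    return 1.0 / (1.0 + math.exp(-x))


def binary_cross_entropy(p_flip, flipped):
    """Textbook per-shot loss (an assumption, not the library's code).
    There is no guard on log(), so p_flip must not be exactly 0 or 1."""
    return -math.log(p_flip if flipped else 1.0 - p_flip)


priors = [0.1, 0.2, 0.3]

# Logit-space training: optimize unconstrained logits, recover probabilities
# with sigmoid whenever they are needed.
logits = [logit(clamp_prob(p)) for p in priors]
recovered = [sigmoid(x) for x in logits]

# Direct probability training: a step may land on 0 or 1, so clamp before
# the loss is evaluated to keep log() finite.
stepped = 0.0  # hypothetical value after an overshooting update
safe_loss = binary_cross_entropy(clamp_prob(stepped), flipped=True)
```

Training in logit space avoids the per-step clamp entirely, which is why the quick-start example hands logits rather than probabilities to the optimizer.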