diff --git a/src/BUILD b/src/BUILD index ebac6a5..cb021d4 100644 --- a/src/BUILD +++ b/src/BUILD @@ -140,6 +140,32 @@ cc_library( ], ) +cc_library( + name = "libtesseract_trellis", + srcs = ["tesseract_trellis.cc"], + hdrs = ["tesseract_trellis.h"], + copts = OPT_COPTS, + linkopts = OPT_LINKOPTS, + deps = [ + ":libcommon", + ":libutils", + "@stim//:stim_lib", + ], +) + + +cc_library( + name = "libtesseract_ftl", + srcs = ["tesseract_ftl.cc"], + hdrs = ["tesseract_ftl.h"], + copts = OPT_COPTS, + linkopts = OPT_LINKOPTS, + deps = [ + ":libtesseract", + "@highs", + ], +) + cc_binary( name = "tesseract", srcs = ["tesseract_main.cc"], @@ -153,6 +179,32 @@ cc_binary( ], ) +cc_binary( + name = "tesseract_trellis", + srcs = ["tesseract_trellis_main.cc"], + copts = OPT_COPTS, + linkopts = OPT_LINKOPTS, + deps = [ + ":libtesseract_trellis", + "@argparse", + "@nlohmann_json//:json", + "@stim//:stim_lib", + ], +) + +cc_binary( + name = "tesseract_ftl", + srcs = ["tesseract_ftl_main.cc"], + copts = OPT_COPTS, + linkopts = OPT_LINKOPTS, + deps = [ + ":libtesseract_ftl", + "@argparse", + "@nlohmann_json//:json", + "@stim//:stim_lib", + ], +) + cc_test( name = "tesseract_tests", timeout = "eternal", diff --git a/src/py/astar/astar_prototype.py b/src/py/astar/astar_prototype.py new file mode 100644 index 0000000..9359813 --- /dev/null +++ b/src/py/astar/astar_prototype.py @@ -0,0 +1,611 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with plain detcost or optimal singleton detcost. + +The default heuristic matches the original prototype's plain detector-wise +heuristic. Passing --opt-singleton-detcost switches to the exact optimal +singleton lower bound, solved as a small LP over the currently active +residual detectors. + +Notes: + * The search still uses the precedence-based tree pruning from the + prototype. + * By default, the heuristic ignores precedence-blocked errors in order to + preserve the original prototype's behavior. 
Use + --respect-blocked-errors-in-heuristic to exclude blocked errors from the + heuristic as well. + * The optimal singleton heuristic requires SciPy (``scipy.optimize.linprog``). +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] + + +@dataclass +class SearchState: + errs: np.ndarray + blocked_errs: np.ndarray + dets: np.ndarray + det_counts: np.ndarray + g_cost: float + + +@dataclass +class DecodeResult: + success: bool + errs: np.ndarray + residual_dets: np.ndarray + cost: float + nodes_pushed: int + nodes_popped: int + heuristic_calls: int + lp_calls: int + elapsed_seconds: float + + +class AStarPrototypeDecoder: + def __init__( + self, + errors: Sequence[ErrorRecord], + num_detectors: int, + *, + use_opt_singleton_detcost: bool = False, + respect_blocked_errors_in_heuristic: bool = False, + verbose_search: bool = False, + ) -> None: + self.errors = list(errors) + self.num_detectors = int(num_detectors) + self.num_errors = len(self.errors) + self.use_opt_singleton_detcost = use_opt_singleton_detcost + self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic + self.verbose_search = verbose_search + + if self.use_opt_singleton_detcost and linprog is None: + raise RuntimeError( + "--opt-singleton-detcost requires scipy. Install scipy and rerun." 
+ ) + + self.ecosts = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64) + self.edets: List[np.ndarray] = [ + np.array(err.detectors, dtype=np.int32) for err in self.errors + ] + self.eobs: List[np.ndarray] = [ + np.array(err.observables, dtype=np.int32) for err in self.errors + ] + + d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)] + for ei, dets in enumerate(self.edets): + for d in dets: + d2e_lists[int(d)].append(ei) + self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists] + + self.heuristic_calls = 0 + self.lp_calls = 0 + + @property + def heuristic_name(self) -> str: + if self.use_opt_singleton_detcost: + return "opt-singleton-detcost" + return "plain-detcost" + + def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + return available + + def _plain_detcost_heuristic( + self, + available_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> float: + total = 0.0 + for d in np.flatnonzero(dets): + best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + assert count > 0 + value = self.ecosts[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF + total += best + return total + + def _opt_singleton_detcost_heuristic( + self, + available_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> float: + active_dets = np.flatnonzero(dets) + if active_dets.size == 0: + return 0.0 + + det_to_var = {int(d): i for i, d in enumerate(active_dets.tolist())} + support_to_weight: Dict[Tuple[int, ...], float] = {} + covered = np.zeros(active_dets.size, dtype=bool) + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + if int(det_counts[ei]) == 0: + continue + support = tuple(det_to_var[int(d)] for d in self.edets[ei] if dets[int(d)]) + if not 
support: + continue + for var in support: + covered[var] = True + weight = float(self.ecosts[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + if not np.all(covered): + return INF + + num_vars = active_dets.size + supports = list(support_to_weight.keys()) + weights = np.array([support_to_weight[s] for s in supports], dtype=np.float64) + + row_indices: List[int] = [] + col_indices: List[int] = [] + data: List[float] = [] + for row, support in enumerate(supports): + row_indices.extend([row] * len(support)) + col_indices.extend(support) + data.extend([1.0] * len(support)) + + + a_ub = csr_matrix( + (data, (row_indices, col_indices)), + shape=(len(supports), num_vars), + dtype=np.float64, + ) + + self.lp_calls += 1 + result = linprog( + c=-np.ones(num_vars, dtype=np.float64), + A_ub=a_ub, + b_ub=weights, + bounds=[(0.0, None)] * num_vars, + method="highs", + ) + if result.status == 0: + return max(0.0, float(-result.fun)) + if result.status in {2, 3}: # infeasible or unbounded + return INF + raise RuntimeError(f"linprog failed with status={result.status}: {result.message}") + + def heuristic_cost( + self, + errs: np.ndarray, + blocked_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> float: + self.heuristic_calls += 1 + available = self._available_errors(errs, blocked_errs) + if self.use_opt_singleton_detcost: + return self._opt_singleton_detcost_heuristic(available, dets, det_counts) + return self._plain_detcost_heuristic(available, dets, det_counts) + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.heuristic_calls = 0 + self.lp_calls = 0 + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in 
self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + h0 = self.heuristic_cost(errs0, blocked0, dets0, det_counts0) + if math.isinf(h0): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + heuristic_calls=self.heuristic_calls, + lp_calls=self.lp_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + next_node_id = 1 + heap: List[Tuple[float, int, int]] = [(h0, int(dets0.sum()), 0)] + node_data: Dict[int, SearchState] = { + 0: SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + ) + } + + nodes_pushed = 1 + nodes_popped = 0 + min_num_dets = int(dets0.sum()) + + while heap: + f_cost, num_dets, node_id = heapq.heappop(heap) + state = node_data.pop(node_id, None) + if state is None: + continue + nodes_popped += 1 + + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + errs = state.errs + blocked_errs = state.blocked_errs + dets = state.dets + det_counts = state.det_counts + g_cost = state.g_cost + + if self.verbose_search: + print( + f"len(heap)={len(heap)} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={g_cost:.6f}" + ) + + if num_dets == 0: + return DecodeResult( + success=True, + errs=errs, + residual_dets=dets, + cost=g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + lp_calls=self.lp_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + min_det = int(np.flatnonzero(dets)[0]) + prefix_blocked_errs = blocked_errs.copy() + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked_errs[ei] = True + + if errs[ei] or blocked_errs[ei]: + continue + + child_errs = errs.copy() + child_errs[ei] = True + child_blocked_errs = prefix_blocked_errs.copy() + 
child_dets = dets.copy() + child_det_counts = det_counts.copy() + + for d in self.edets[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + continue + + child_g = g_cost + float(self.ecosts[ei]) + child_h = self.heuristic_cost( + child_errs, + child_blocked_errs, + child_dets, + child_det_counts, + ) + if math.isinf(child_h): + continue + + child_id = next_node_id + next_node_id += 1 + node_data[child_id] = SearchState( + errs=child_errs, + blocked_errs=child_blocked_errs, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, child_id)) + nodes_pushed += 1 + + return DecodeResult( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + lp_calls=self.lp_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.ecosts[errs].sum()) + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.eobs[int(ei)]: + obs = int(obs) + parity[obs] = not parity.get(obs, False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.edets[int(ei)]: + dets[int(d)] ^= True + return dets + + +def merged_errors_from_dem(dem) -> List[ErrorRecord]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + + for 
error in dem.flattened(): + if error.type != "error": + continue + + probability = float(error.args_copy()[0]) + if probability <= 0: + continue + if probability > 0.5: + raise ValueError( + f"Expected flattened error probabilities in (0, 0.5], got {probability}." + ) + + detectors: set[int] = set() + observables: set[int] = set() + for target in error.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected target type: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + + key = (tuple(sorted(detectors)), tuple(sorted(observables))) + p_old = errors_by_symptom.get(key) + if p_old is None: + p_new = probability + else: + # Two independent identical symptoms combine by XORing their parity. + p_new = p_old * (1.0 - probability) + (1.0 - p_old) * probability + errors_by_symptom[key] = p_new + + merged: List[ErrorRecord] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + merged.append( + ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +def sample_detections_and_observables(circuit, num_shots: int, seed: int) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=circuit.num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=circuit.num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) 
-> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder using the plain detector-wise heuristic or the " + "optimal singleton detector heuristic." + ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--shot", + type=int, + default=0, + help="Shot index to decode after sampling --sample-num-shots shots (default: 0).", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample before selecting --shot (default: 100).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the number of residual detections; use 'inf' for none (default: inf).", + ) + parser.add_argument( + "--opt-singleton-detcost", + action="store_true", + help=( + "Use the exact optimal singleton detector-cost lower bound instead of the " + "plain detector-wise lower bound. Requires scipy." + ), + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action="store_true", + help=( + "Exclude precedence-blocked errors from the heuristic. By default the script " + "preserves the original prototype's behavior and only excludes already-activated errors." 
+ ), + ) + parser.add_argument( + "--show-detections", + action="store_true", + help="Print the selected shot's detection events before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action="store_true", + help="Print the decoded merged-error indices.", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print one line per expanded node during A* search.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + try: + import stim + except ImportError as exc: # pragma: no cover - depends on runtime environment. + raise SystemExit("This script requires the 'stim' package to be installed.") from exc + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) + + dets_unpacked, obs_unpacked = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + ) + shot_dets = dets_unpacked[args.shot] + shot_obs = obs_unpacked[args.shot] + + if args.show_detections: + active_dets = np.flatnonzero(shot_dets) + print("detections:", " ".join(f"D{d}" for d in active_dets)) + + decoder = AStarPrototypeDecoder( + errors, + dem.num_detectors, + use_opt_singleton_detcost=args.opt_singleton_detcost, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + verbose_search=args.verbose_search, + ) + result = decoder.decode(shot_dets, det_beam=args.det_beam) + + print(f"heuristic: {decoder.heuristic_name}") + print(f"shot: {args.shot} / {args.sample_num_shots}") + print(f"success: {result.success}") + print(f"nodes_pushed: 
{result.nodes_pushed}") + print(f"nodes_popped: {result.nodes_popped}") + print(f"heuristic_calls: {result.heuristic_calls}") + print(f"lp_calls: {result.lp_calls}") + print(f"elapsed_seconds: {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + decoded_err_indices = np.flatnonzero(result.errs) + if args.show_error_indices: + print("decoded_error_indices:", " ".join(map(str, decoded_err_indices.tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + reproduced_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + actual_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors: {int(result.errs.sum())}") + print(f"decoded_cost: {reproduced_cost:.12f}") + print("predicted_observables:", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables:", " ".join(f"L{o}" for o in actual_obs.tolist())) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_incremental_greedy.py b/src/py/astar/astar_prototype_incremental_greedy.py new file mode 100644 index 0000000..4dc6a2d --- /dev/null +++ b/src/py/astar/astar_prototype_incremental_greedy.py @@ -0,0 +1,969 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with incremental greedy singleton heuristics. 
+ +Heuristic modes: + --heuristic plain exact plain detcost via incremental support updates + --heuristic asc-deg exact ascending-degree saturation heuristic + --heuristic plain-sweep exact plain+one-sweep saturation heuristic + --heuristic best-of-two max(asc-deg, plain-sweep) + +All four heuristics are maintained incrementally: + * the deduplicated active-support dictionary W(T) is updated from parent to + child using only errors touching flipped detectors; + * heuristic values are recomputed only on the union of touched connected + components of the active-support hypergraph; + * untouched components inherit their detector prices exactly. + +This stays inside the singleton-family lower-bound framework, but avoids any LP +solves while still being much tighter than basic detcost in practice. +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple + +import numpy as np +import stim + +INF = math.inf + +SupportKey = Tuple[int, ...] + + +@dataclass(frozen=True) +class MergedError: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
+ + +@dataclass +class DecoderData: + num_detectors: int + num_observables: int + errors: List[MergedError] + detector_to_errors: List[np.ndarray] + error_costs: np.ndarray + error_detectors: List[np.ndarray] + error_observables: List[np.ndarray] + + +@dataclass +class SupportState: + support_to_errors: Dict[SupportKey, FrozenSet[int]] + support_to_weight: Dict[SupportKey, float] + detector_to_supports: Dict[int, FrozenSet[SupportKey]] + + +@dataclass +class HeuristicCache: + support_state: SupportState + h_value: float + y_plain: Optional[np.ndarray] = None + y_asc: Optional[np.ndarray] = None + y_sweep: Optional[np.ndarray] = None + + +@dataclass +class SearchState: + activated_errors: Tuple[int, ...] + errs: np.ndarray + blocked_errors: np.ndarray + active_detectors: np.ndarray + path_cost: float + heuristic_cache: HeuristicCache + + +@dataclass +class DecodeStats: + num_pq_pushed: int + num_nodes_popped: int + max_queue_size: int + heuristic_evaluations: int + support_build_calls: int + support_build_seconds: float + support_update_calls: int + support_update_seconds: float + component_recompute_calls: int + component_recompute_seconds: float + incremental_children: int + changed_supports_total: int + touched_detectors_total: int + elapsed_seconds: float + heuristic_name: str + + +@dataclass +class DecodeResult: + activated_errors: Tuple[int, ...] 
+ path_cost: float + stats: DecodeStats + + +class UnionFind: + def __init__(self, size: int) -> None: + self.parent = list(range(size)) + self.rank = [0] * size + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + +class IncrementalGreedyHeuristic: + def __init__( + self, + data: DecoderData, + *, + mode: str, + ) -> None: + valid_modes = {"plain", "asc-deg", "plain-sweep", "best-of-two"} + if mode not in valid_modes: + raise ValueError(f"Unknown heuristic mode: {mode!r}") + self.data = data + self.mode = mode + self.reset_stats() + + def reset_stats(self) -> None: + self.heuristic_evaluations = 0 + self.support_build_calls = 0 + self.support_build_seconds = 0.0 + self.support_update_calls = 0 + self.support_update_seconds = 0.0 + self.component_recompute_calls = 0 + self.component_recompute_seconds = 0.0 + self.incremental_children = 0 + self.changed_supports_total = 0 + self.touched_detectors_total = 0 + + @property + def heuristic_name(self) -> str: + return f"{self.mode}-incremental" + + def _active_support(self, active_detectors: np.ndarray, error_index: int) -> Optional[SupportKey]: + support = tuple(int(d) for d in self.data.error_detectors[error_index] if active_detectors[int(d)]) + return support if support else None + + def _build_support_state_from_scratch( + self, + errs: np.ndarray, + blocked_errors: np.ndarray, + active_detectors: np.ndarray, + ) -> SupportState: + t0 = time.perf_counter() + self.support_build_calls += 1 + + support_to_errors_mut: Dict[SupportKey, Set[int]] = {} + for error_index in range(len(self.data.errors)): + if errs[error_index] or blocked_errors[error_index]: 
+ continue + support = self._active_support(active_detectors, error_index) + if support is None: + continue + bucket = support_to_errors_mut.setdefault(support, set()) + bucket.add(error_index) + + support_to_errors: Dict[SupportKey, FrozenSet[int]] = {} + support_to_weight: Dict[SupportKey, float] = {} + detector_to_supports_mut: Dict[int, Set[SupportKey]] = defaultdict(set) + for support, bucket in support_to_errors_mut.items(): + frozen = frozenset(bucket) + support_to_errors[support] = frozen + support_to_weight[support] = float(min(self.data.error_costs[ei] for ei in frozen)) + for detector in support: + detector_to_supports_mut[detector].add(support) + detector_to_supports = { + detector: frozenset(supports) + for detector, supports in detector_to_supports_mut.items() + if supports + } + + self.support_build_seconds += time.perf_counter() - t0 + return SupportState( + support_to_errors=support_to_errors, + support_to_weight=support_to_weight, + detector_to_supports=detector_to_supports, + ) + + def _update_support_state_incremental( + self, + parent_support_state: SupportState, + parent_errs: np.ndarray, + child_errs: np.ndarray, + parent_blocked: np.ndarray, + child_blocked: np.ndarray, + parent_active_detectors: np.ndarray, + child_active_detectors: np.ndarray, + flipped_detectors: np.ndarray, + ) -> Tuple[SupportState, Set[SupportKey], Set[int]]: + t0 = time.perf_counter() + self.support_update_calls += 1 + + affected_errors: Set[int] = set() + for detector in flipped_detectors: + for error_index in self.data.detector_to_errors[int(detector)]: + affected_errors.add(int(error_index)) + + child_support_to_errors = dict(parent_support_state.support_to_errors) + child_support_to_weight = dict(parent_support_state.support_to_weight) + touched_buckets: Dict[SupportKey, Set[int]] = {} + + def get_touched_bucket(support: SupportKey) -> Set[int]: + bucket = touched_buckets.get(support) + if bucket is None: + bucket = 
set(parent_support_state.support_to_errors.get(support, frozenset())) + touched_buckets[support] = bucket + return bucket + + for error_index in affected_errors: + old_available = (not parent_errs[error_index]) and (not parent_blocked[error_index]) + new_available = (not child_errs[error_index]) and (not child_blocked[error_index]) + old_support = self._active_support(parent_active_detectors, error_index) if old_available else None + new_support = self._active_support(child_active_detectors, error_index) if new_available else None + if old_support == new_support: + continue + if old_support is not None: + get_touched_bucket(old_support).discard(error_index) + if new_support is not None: + get_touched_bucket(new_support).add(error_index) + + changed_supports: Set[SupportKey] = set() + touched_detectors: Set[int] = set() + + child_detector_to_supports = dict(parent_support_state.detector_to_supports) + touched_detector_sets: Dict[int, Set[SupportKey]] = {} + + for support, bucket in touched_buckets.items(): + old_bucket = parent_support_state.support_to_errors.get(support, frozenset()) + old_present = support in parent_support_state.support_to_weight + new_present = bool(bucket) + + if new_present: + frozen_bucket = frozenset(bucket) + child_support_to_errors[support] = frozen_bucket + new_weight = float(min(self.data.error_costs[ei] for ei in frozen_bucket)) + child_support_to_weight[support] = new_weight + if (not old_present) or frozen_bucket != old_bucket or abs(new_weight - parent_support_state.support_to_weight.get(support, 0.0)) > 1e-12: + changed_supports.add(support) + else: + child_support_to_errors.pop(support, None) + if old_present: + child_support_to_weight.pop(support, None) + changed_supports.add(support) + + if old_present != new_present: + for detector in support: + detector_bucket = touched_detector_sets.get(detector) + if detector_bucket is None: + detector_bucket = set(parent_support_state.detector_to_supports.get(detector, frozenset())) + 
touched_detector_sets[detector] = detector_bucket + if new_present: + detector_bucket.add(support) + else: + detector_bucket.discard(support) + + for support in changed_supports: + touched_detectors.update(support) + + for detector, supports in touched_detector_sets.items(): + if supports: + child_detector_to_supports[detector] = frozenset(supports) + else: + child_detector_to_supports.pop(detector, None) + + self.incremental_children += 1 + self.changed_supports_total += len(changed_supports) + self.touched_detectors_total += len(touched_detectors) + self.support_update_seconds += time.perf_counter() - t0 + + return ( + SupportState( + support_to_errors=child_support_to_errors, + support_to_weight=child_support_to_weight, + detector_to_supports=child_detector_to_supports, + ), + changed_supports, + touched_detectors, + ) + + def _component_from_seed_detectors( + self, + support_state: SupportState, + seed_detectors: Iterable[int], + active_detectors: np.ndarray, + ) -> Tuple[Set[int], Set[SupportKey]]: + seen_detectors: Set[int] = set() + seen_supports: Set[SupportKey] = set() + stack = [int(d) for d in seed_detectors if active_detectors[int(d)] and int(d) in support_state.detector_to_supports] + + while stack: + detector = stack.pop() + if detector in seen_detectors: + continue + seen_detectors.add(detector) + for support in support_state.detector_to_supports.get(detector, frozenset()): + if support in seen_supports: + continue + seen_supports.add(support) + for other_detector in support: + if active_detectors[other_detector] and other_detector not in seen_detectors: + stack.append(other_detector) + return seen_detectors, seen_supports + + def _all_components(self, support_state: SupportState) -> List[Tuple[Set[int], Set[SupportKey]]]: + components: List[Tuple[Set[int], Set[SupportKey]]] = [] + seen_detectors: Set[int] = set() + for detector in sorted(support_state.detector_to_supports): + if detector in seen_detectors: + continue + dets, supports = 
self._component_from_seed_detectors( + support_state=support_state, + seed_detectors=[detector], + active_detectors=np.ones(self.data.num_detectors, dtype=bool), + ) + seen_detectors.update(dets) + components.append((dets, supports)) + return components + + def _component_incidence( + self, + component_detectors: Set[int], + component_supports: Set[SupportKey], + support_state: SupportState, + ) -> Dict[int, List[SupportKey]]: + component_supports_set = set(component_supports) + incidence: Dict[int, List[SupportKey]] = {} + for detector in component_detectors: + local_supports = [ + support + for support in support_state.detector_to_supports.get(detector, frozenset()) + if support in component_supports_set + ] + incidence[detector] = local_supports + return incidence + + def _compute_plain_component( + self, + component_detectors: Set[int], + component_supports: Set[SupportKey], + support_state: SupportState, + ) -> Dict[int, float]: + incidence = self._component_incidence(component_detectors, component_supports, support_state) + y: Dict[int, float] = {} + for detector in component_detectors: + best = INF + for support in incidence[detector]: + candidate = support_state.support_to_weight[support] / len(support) + if candidate < best: + best = candidate + if math.isinf(best): + raise RuntimeError("Detector in active support component has no incident support.") + y[detector] = best + return y + + def _compute_asc_component( + self, + component_detectors: Set[int], + component_supports: Set[SupportKey], + support_state: SupportState, + ) -> Dict[int, float]: + incidence = self._component_incidence(component_detectors, component_supports, support_state) + order = sorted(component_detectors, key=lambda d: (len(incidence[d]), d)) + slacks = {support: float(support_state.support_to_weight[support]) for support in component_supports} + y: Dict[int, float] = {} + for detector in order: + value = min(slacks[support] for support in incidence[detector]) + y[detector] = value + 
            # Tail of a price-update loop whose definition starts before this view.
            # NOTE(review): indentation of these three lines is inferred from the
            # matching pattern in _compute_plain_sweep_component below — confirm.
            for support in incidence[detector]:
                slacks[support] -= value
        return y

    def _compute_plain_sweep_component(
        self,
        component_detectors: Set[int],
        component_supports: Set[SupportKey],
        support_state: SupportState,
    ) -> Dict[int, float]:
        """Plain component prices followed by one greedy slack-filling sweep.

        Starts from the plain per-detector prices, computes the remaining slack
        of every support constraint, then sweeps detectors (largest price
        first) and raises each price by the minimum slack over its incident
        supports, keeping all constraints satisfied.
        """
        incidence = self._component_incidence(component_detectors, component_supports, support_state)
        y = self._compute_plain_component(component_detectors, component_supports, support_state)
        # Residual slack of each support constraint: weight minus current price mass.
        slacks = {
            support: float(support_state.support_to_weight[support]) - sum(y[detector] for detector in support)
            for support in component_supports
        }
        # Sweep in decreasing price order (ties broken by detector id for determinism).
        order = sorted(component_detectors, key=lambda d: (-y[d], d))
        for detector in order:
            # The largest uniform raise this detector can absorb without
            # violating any support it participates in.
            delta = min(slacks[support] for support in incidence[detector])
            y[detector] += delta
            for support in incidence[detector]:
                slacks[support] -= delta
        return y

    def _build_cache_from_support_state(self, support_state: SupportState) -> HeuristicCache:
        """Compute a full HeuristicCache from scratch for `support_state`.

        Evaluates the configured heuristic mode component by component and
        records the aggregate lower bound `h_value` plus the per-detector
        price vectors needed for later incremental updates.
        """
        t0 = time.perf_counter()
        self.heuristic_evaluations += 1
        self.component_recompute_calls += 1

        # Only allocate the price vectors the active mode actually uses.
        y_plain = np.zeros(self.data.num_detectors, dtype=np.float64) if self.mode == "plain" else None
        y_asc = np.zeros(self.data.num_detectors, dtype=np.float64) if self.mode in {"asc-deg", "best-of-two"} else None
        y_sweep = np.zeros(self.data.num_detectors, dtype=np.float64) if self.mode in {"plain-sweep", "best-of-two"} else None

        for component_detectors, component_supports in self._all_components(support_state):
            if self.mode == "plain":
                comp = self._compute_plain_component(component_detectors, component_supports, support_state)
                for detector, value in comp.items():
                    y_plain[detector] = value
            elif self.mode == "asc-deg":
                comp = self._compute_asc_component(component_detectors, component_supports, support_state)
                for detector, value in comp.items():
                    y_asc[detector] = value
            elif self.mode == "plain-sweep":
                comp = self._compute_plain_sweep_component(component_detectors, component_supports, support_state)
                for detector, value in comp.items():
                    y_sweep[detector] = value
            elif self.mode == "best-of-two":
                # best-of-two maintains both price vectors and takes the max at the end.
                comp_asc = self._compute_asc_component(component_detectors, component_supports, support_state)
                comp_sweep = self._compute_plain_sweep_component(component_detectors, component_supports, support_state)
                for detector, value in comp_asc.items():
                    y_asc[detector] = value
                for detector, value in comp_sweep.items():
                    y_sweep[detector] = value
            else:
                raise AssertionError("unreachable")

        if self.mode == "plain":
            h_value = float(y_plain.sum())
        elif self.mode == "asc-deg":
            h_value = float(y_asc.sum())
        elif self.mode == "plain-sweep":
            h_value = float(y_sweep.sum())
        else:
            # best-of-two: both vectors are valid lower bounds, so the max is too.
            h_value = float(max(y_asc.sum(), y_sweep.sum()))

        self.component_recompute_seconds += time.perf_counter() - t0
        return HeuristicCache(
            support_state=support_state,
            h_value=h_value,
            y_plain=y_plain,
            y_asc=y_asc,
            y_sweep=y_sweep,
        )

    def _incremental_child_cache(
        self,
        parent_cache: HeuristicCache,
        child_support_state: SupportState,
        touched_detectors: Set[int],
        child_active_detectors: np.ndarray,
        flipped_detectors: np.ndarray,
    ) -> HeuristicCache:
        """Derive a child's HeuristicCache by recomputing only the touched component.

        Copies the parent's price vectors, zeroes prices of detectors that were
        flipped off or belong to the touched component, then recomputes prices
        for that single component under the child's support state.
        """
        t0 = time.perf_counter()
        self.heuristic_evaluations += 1
        self.component_recompute_calls += 1

        # Grow the recompute region from the detectors whose supports changed.
        touched_component_detectors, touched_component_supports = self._component_from_seed_detectors(
            support_state=child_support_state,
            seed_detectors=touched_detectors,
            active_detectors=child_active_detectors,
        )

        y_plain = None if parent_cache.y_plain is None else parent_cache.y_plain.copy()
        y_asc = None if parent_cache.y_asc is None else parent_cache.y_asc.copy()
        y_sweep = None if parent_cache.y_sweep is None else parent_cache.y_sweep.copy()

        # Detectors deactivated by the applied error no longer contribute price.
        for detector in flipped_detectors:
            detector = int(detector)
            if not child_active_detectors[detector]:
                if y_plain is not None:
                    y_plain[detector] = 0.0
                if y_asc is not None:
                    y_asc[detector] = 0.0
                if y_sweep is not None:
                    y_sweep[detector] = 0.0

        # Reset the touched component before recomputing it below.
        for detector in touched_component_detectors:
            if y_plain is not None:
                y_plain[detector] = 0.0
            if y_asc is not None:
                y_asc[detector] = 0.0
            if y_sweep is not None:
                y_sweep[detector] = 0.0

        if touched_component_detectors:
            if self.mode == "plain":
                comp = self._compute_plain_component(touched_component_detectors, touched_component_supports, child_support_state)
                for detector, value in comp.items():
                    y_plain[detector] = value
            elif self.mode == "asc-deg":
                comp = self._compute_asc_component(touched_component_detectors, touched_component_supports, child_support_state)
                for detector, value in comp.items():
                    y_asc[detector] = value
            elif self.mode == "plain-sweep":
                comp = self._compute_plain_sweep_component(touched_component_detectors, touched_component_supports, child_support_state)
                for detector, value in comp.items():
                    y_sweep[detector] = value
            elif self.mode == "best-of-two":
                comp_asc = self._compute_asc_component(touched_component_detectors, touched_component_supports, child_support_state)
                comp_sweep = self._compute_plain_sweep_component(touched_component_detectors, touched_component_supports, child_support_state)
                for detector, value in comp_asc.items():
                    y_asc[detector] = value
                for detector, value in comp_sweep.items():
                    y_sweep[detector] = value
            else:
                raise AssertionError("unreachable")

        if self.mode == "plain":
            h_value = float(y_plain.sum())
        elif self.mode == "asc-deg":
            h_value = float(y_asc.sum())
        elif self.mode == "plain-sweep":
            h_value = float(y_sweep.sum())
        else:
            h_value = float(max(y_asc.sum(), y_sweep.sum()))

        self.component_recompute_seconds += time.perf_counter() - t0
        return HeuristicCache(
            support_state=child_support_state,
            h_value=h_value,
            y_plain=y_plain,
            y_asc=y_asc,
            y_sweep=y_sweep,
        )

    def build_root_cache(
        self,
        errs: np.ndarray,
        blocked_errors: np.ndarray,
        active_detectors: np.ndarray,
    ) -> HeuristicCache:
        """Build the root node's cache with a from-scratch support-state construction."""
        support_state = self._build_support_state_from_scratch(errs,
blocked_errors, active_detectors) + return self._build_cache_from_support_state(support_state) + + def build_child_cache( + self, + parent_state: SearchState, + child_errs: np.ndarray, + child_blocked_errors: np.ndarray, + child_active_detectors: np.ndarray, + flipped_detectors: np.ndarray, + ) -> HeuristicCache: + child_support_state, _changed_supports, touched_detectors = self._update_support_state_incremental( + parent_support_state=parent_state.heuristic_cache.support_state, + parent_errs=parent_state.errs, + child_errs=child_errs, + parent_blocked=parent_state.blocked_errors, + child_blocked=child_blocked_errors, + parent_active_detectors=parent_state.active_detectors, + child_active_detectors=child_active_detectors, + flipped_detectors=flipped_detectors, + ) + return self._incremental_child_cache( + parent_cache=parent_state.heuristic_cache, + child_support_state=child_support_state, + touched_detectors=touched_detectors, + child_active_detectors=child_active_detectors, + flipped_detectors=flipped_detectors, + ) + + +def xor_probability(p0: float, p1: float) -> float: + return p0 * (1.0 - p1) + (1.0 - p0) * p1 + + +def iter_dem_errors(dem: stim.DetectorErrorModel) -> Iterable[MergedError]: + for instruction in dem.flattened(): + if instruction.type != "error": + continue + probability = float(instruction.args_copy()[0]) + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError("This prototype assumes DEM probabilities in (0, 0.5).") + detectors: Set[int] = set() + observables: Set[int] = set() + for target in instruction.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected DEM target: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + yield MergedError( + 
            probability=probability,
            likelihood_cost=float(-math.log(probability / (1.0 - probability))),
            detectors=tuple(sorted(detectors)),
            observables=tuple(sorted(observables)),
        )


def merged_errors(dem: stim.DetectorErrorModel) -> List[MergedError]:
    """Merge DEM errors with identical (detectors, observables) symptoms.

    Probabilities of indistinguishable errors are combined with XOR semantics
    (exactly-one-fires), then the merged list is rebuilt with recomputed
    likelihood costs. Merged probabilities must stay in (0, 0.5).
    """
    errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {}
    for error in iter_dem_errors(dem):
        key = (error.detectors, error.observables)
        previous = errors_by_symptom.get(key)
        if previous is None:
            errors_by_symptom[key] = error.probability
        else:
            errors_by_symptom[key] = xor_probability(previous, error.probability)

    merged: List[MergedError] = []
    for (detectors, observables), probability in errors_by_symptom.items():
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError("Merged error has probability >= 0.5.")
        merged.append(
            MergedError(
                probability=probability,
                likelihood_cost=float(-math.log(probability / (1.0 - probability))),
                detectors=detectors,
                observables=observables,
            )
        )
    return merged


def build_decoder_data(dem: stim.DetectorErrorModel, *, merge_errors_in_dem: bool = True) -> DecoderData:
    """Precompute the arrays the decoder needs from a DEM.

    Builds the error list (optionally merged), a detector -> incident-error
    index, and per-error cost/detector/observable arrays.
    """
    errors = merged_errors(dem) if merge_errors_in_dem else list(iter_dem_errors(dem))
    detector_to_errors_lists: List[List[int]] = [[] for _ in range(dem.num_detectors)]
    for error_index, error in enumerate(errors):
        for detector in error.detectors:
            detector_to_errors_lists[detector].append(error_index)
    return DecoderData(
        num_detectors=dem.num_detectors,
        num_observables=dem.num_observables,
        errors=errors,
        detector_to_errors=[np.asarray(v, dtype=np.int32) for v in detector_to_errors_lists],
        error_costs=np.asarray([err.likelihood_cost for err in errors], dtype=np.float64),
        error_detectors=[np.asarray(err.detectors, dtype=np.int32) for err in errors],
        error_observables=[np.asarray(err.observables, dtype=np.int32) for err in errors],
    )


def unpack_bit_packed_rows(bits: np.ndarray, count: int) -> np.ndarray:
    """Unpack little-endian bit-packed sample rows to a bool array with `count` columns."""
    return np.unpackbits(bits, bitorder="little", axis=1, count=count).astype(bool, copy=False)


def detectors_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray:
    """XOR together the detector footprints of the activated errors (syndrome check)."""
    detectors = np.zeros(data.num_detectors, dtype=bool)
    for error_index in activated_errors:
        for detector in data.error_detectors[error_index]:
            detectors[int(detector)] ^= True
    return detectors


def observables_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray:
    """XOR together the observable flips of the activated errors (the prediction)."""
    observables = np.zeros(data.num_observables, dtype=bool)
    for error_index in activated_errors:
        for observable in data.error_observables[error_index]:
            observables[int(observable)] ^= True
    return observables


def parse_beam(text: str) -> float:
    """argparse type for --det-beam: a non-negative integer or 'inf'/'none'."""
    lowered = text.strip().lower()
    if lowered in {"inf", "+inf", "infinity", "+infinity", "none"}:
        return INF
    # A plain ValueError from int() is also reported by argparse as an
    # invalid value for this option.
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError("beam must be non-negative or 'inf'")
    return float(value)


def format_indices(indices: Iterable[int], prefix: str) -> str:
    """Render indices as e.g. 'D0 D3 D7' (or '(none)' when empty)."""
    items = list(indices)
    if not items:
        return "(none)"
    return " ".join(f"{prefix}{i}" for i in items)


def decode(
    data: DecoderData,
    detections: np.ndarray,
    *,
    det_beam: float,
    heuristic: IncrementalGreedyHeuristic,
    verbose_search: bool = False,
) -> DecodeResult:
    """A* search for a lowest-cost error set explaining `detections`.

    Nodes are keyed by f = g + h in a lazy-deletion priority queue (stale
    entries are skipped when their id is missing from `node_data`). A beam on
    the residual detector count (`det_beam` above the best seen so far)
    prunes both popped nodes and generated children.

    NOTE(review): this constructs SearchState/DecodeResult with fields
    (activated_errors, path_cost, heuristic_cache, stats, ...) that differ
    from the identically-named dataclasses visible near the top of the file —
    presumably the module defines the matching variants elsewhere; confirm.
    """
    start_time = time.perf_counter()
    heuristic.reset_stats()

    root_dets = np.asarray(detections, dtype=bool).copy()
    root_errs = np.zeros(len(data.errors), dtype=bool)
    root_blocked = np.zeros(len(data.errors), dtype=bool)
    root_cache = heuristic.build_root_cache(root_errs, root_blocked, root_dets)
    root_state = SearchState(
        activated_errors=(),
        errs=root_errs,
        blocked_errors=root_blocked,
        active_detectors=root_dets,
        path_cost=0.0,
        heuristic_cache=root_cache,
    )

    # Heap entries are (f, residual detector count, node id); the id maps into
    # node_data, which doubles as the lazy-deletion liveness set.
    heap: List[Tuple[float, int, int]] = [(root_state.path_cost + root_state.heuristic_cache.h_value, int(root_dets.sum()), 0)]
    node_data: Dict[int, SearchState] = {0: root_state}
    next_node_id = 1

    num_pq_pushed = 1
    num_nodes_popped = 0
    max_queue_size = 1
    min_num_dets = int(root_dets.sum())

    while heap:
        max_queue_size = max(max_queue_size, len(heap))
        f_cost, num_dets, node_id = heapq.heappop(heap)
        state = node_data.pop(node_id, None)
        if state is None:
            # Stale heap entry (node superseded or already expanded).
            continue
        num_nodes_popped += 1

        # Beam: discard nodes whose residual detector count exceeds the best
        # seen so far plus the beam width.
        max_num_dets = INF if det_beam == INF else min_num_dets + det_beam
        if num_dets > max_num_dets:
            continue
        if num_dets < min_num_dets:
            min_num_dets = num_dets
            max_num_dets = INF if det_beam == INF else min_num_dets + det_beam

        if verbose_search:
            print(
                f"len(heap)={len(heap)} nodes_pushed={num_pq_pushed} nodes_popped={num_nodes_popped} "
                f"active_dets={num_dets} beam_max={max_num_dets} depth={len(state.activated_errors)} "
                f"f={f_cost:.12g} g={state.path_cost:.12g} h={state.heuristic_cache.h_value:.12g}"
            )

        if num_dets == 0:
            # Goal: all detectors explained. First goal popped is optimal
            # (subject to the beam) because h is a lower bound.
            elapsed = time.perf_counter() - start_time
            return DecodeResult(
                activated_errors=state.activated_errors,
                path_cost=state.path_cost,
                stats=DecodeStats(
                    num_pq_pushed=num_pq_pushed,
                    num_nodes_popped=num_nodes_popped,
                    max_queue_size=max_queue_size,
                    heuristic_evaluations=heuristic.heuristic_evaluations,
                    support_build_calls=heuristic.support_build_calls,
                    support_build_seconds=heuristic.support_build_seconds,
                    support_update_calls=heuristic.support_update_calls,
                    support_update_seconds=heuristic.support_update_seconds,
                    component_recompute_calls=heuristic.component_recompute_calls,
                    component_recompute_seconds=heuristic.component_recompute_seconds,
                    incremental_children=heuristic.incremental_children,
                    changed_supports_total=heuristic.changed_supports_total,
                    touched_detectors_total=heuristic.touched_detectors_total,
                    elapsed_seconds=elapsed,
                    heuristic_name=heuristic.heuristic_name,
                ),
            )

        # Branch on the lowest-index active detector; precedence blocking
        # (blocked_prefix) ensures each combination is generated once.
        min_detector = int(np.flatnonzero(state.active_detectors)[0])
        blocked_prefix = state.blocked_errors.copy()

        children_generated = 0
children_beam_pruned = 0 + for error_index in data.detector_to_errors[min_detector]: + error_index = int(error_index) + blocked_prefix[error_index] = True + if state.errs[error_index] or state.blocked_errors[error_index]: + continue + + child_errs = state.errs.copy() + child_errs[error_index] = True + child_blocked = blocked_prefix.copy() + child_active_detectors = state.active_detectors.copy() + flipped_detectors = data.error_detectors[error_index] + for detector in flipped_detectors: + child_active_detectors[int(detector)] = ~child_active_detectors[int(detector)] + + child_num_dets = int(child_active_detectors.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_cache = heuristic.build_child_cache( + parent_state=state, + child_errs=child_errs, + child_blocked_errors=child_blocked, + child_active_detectors=child_active_detectors, + flipped_detectors=flipped_detectors, + ) + child_state = SearchState( + activated_errors=state.activated_errors + (error_index,), + errs=child_errs, + blocked_errors=child_blocked, + active_detectors=child_active_detectors, + path_cost=state.path_cost + float(data.error_costs[error_index]), + heuristic_cache=child_cache, + ) + child_id = next_node_id + next_node_id += 1 + node_data[child_id] = child_state + heapq.heappush(heap, (child_state.path_cost + child_cache.h_value, child_num_dets, child_id)) + num_pq_pushed += 1 + children_generated += 1 + + if verbose_search: + print( + f" expanded node={node_id} children_generated={children_generated} " + f"beam_pruned={children_beam_pruned} support_updates={heuristic.support_update_calls}" + ) + + raise RuntimeError("Decoding failed to find any completion.") + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Stim-compatible A* prototype with incrementally maintained greedy singleton-family lower bounds." 
        )
    )
    parser.add_argument("--circuit", type=Path, required=True, help="Path to a Stim circuit file.")
    parser.add_argument("--shot", type=int, default=0, help="Zero-based sampled shot index to decode.")
    parser.add_argument("--sample-num-shots", type=int, default=100, help="Number of shots to sample before selecting --shot.")
    parser.add_argument("--seed", type=int, default=27123839530, help="Seed passed to stim.compile_detector_sampler(...).sample(...).")
    parser.add_argument("--det-beam", type=parse_beam, default=INF, help="Beam cutoff on the residual detector count. Use an integer or 'inf'.")
    parser.add_argument(
        "--heuristic",
        choices=["plain", "asc-deg", "plain-sweep", "best-of-two"],
        default="plain-sweep",
        help="Incremental singleton-family heuristic to use.",
    )
    parser.add_argument(
        "--merge-errors",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Merge indistinguishable DEM errors before decoding (default: enabled).",
    )
    parser.add_argument(
        "--show-shot-detectors",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Print the sampled shot's active detector IDs before decoding.",
    )
    parser.add_argument(
        "--show-error-indices",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Print the activated error indices in the final decoding.",
    )
    parser.add_argument("--verbose-search", action="store_true", help="Print per-node search diagnostics.")
    return parser


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Sample shots from a Stim circuit, decode one shot, and print diagnostics.

    Returns 0 on success; exits via parser.error() on invalid CLI arguments.
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)

    if args.sample_num_shots <= 0:
        parser.error("--sample-num-shots must be positive.")
    if args.shot < 0:
        parser.error("--shot must be non-negative.")

    circuit = stim.Circuit.from_file(str(args.circuit))
    dem = circuit.detector_error_model(decompose_errors=False)
    data = build_decoder_data(dem, merge_errors_in_dem=args.merge_errors)

    # Sample bit-packed detector/observable data, then unpack to bool rows.
    dets_packed, obs_packed = circuit.compile_detector_sampler(seed=args.seed).sample(
        shots=args.sample_num_shots,
        separate_observables=True,
        bit_packed=True,
    )
    detections = unpack_bit_packed_rows(dets_packed, count=dem.num_detectors)
    observables = unpack_bit_packed_rows(obs_packed, count=dem.num_observables)

    if args.shot >= detections.shape[0]:
        parser.error(f"--shot={args.shot} is out of range for {detections.shape[0]} sampled shots.")

    shot_detections = detections[args.shot]
    shot_observables = observables[args.shot] if observables.size else np.zeros(0, dtype=bool)

    heuristic = IncrementalGreedyHeuristic(data, mode=args.heuristic)

    print(f"circuit = {args.circuit}")
    print(f"heuristic = {heuristic.heuristic_name}")
    print(f"shot = {args.shot}")
    print(f"sample_num_shots = {args.sample_num_shots}")
    print(f"num_detectors = {data.num_detectors}")
    print(f"num_observables = {data.num_observables}")
    print(f"num_errors = {len(data.errors)}")
    print(f"beam = {args.det_beam}")
    if args.show_shot_detectors:
        print(f"shot_detectors = {format_indices(np.flatnonzero(shot_detections), 'D')}")

    result = decode(
        data=data,
        detections=shot_detections,
        det_beam=args.det_beam,
        heuristic=heuristic,
        verbose_search=args.verbose_search,
    )

    predicted_observables = observables_from_solution(data, result.activated_errors)
    # Sanity check: the decoded error set must exactly reproduce the syndrome.
    reproduced_detectors = detectors_from_solution(data, result.activated_errors)
    if not np.array_equal(reproduced_detectors, shot_detections):
        raise AssertionError("Decoded error set does not reproduce the shot's syndrome.")

    print(f"solution_size = {len(result.activated_errors)}")
    print(f"solution_cost = {result.path_cost:.12g}")
    if args.show_error_indices:
        print(f"activated_errors = {format_indices(result.activated_errors, 'E')}")
    print(f"predicted_observables = {format_indices(np.flatnonzero(predicted_observables), 'L')}")
    print(f"sample_observables = {format_indices(np.flatnonzero(shot_observables), 'L')}")
print(f"observables_match = {bool(np.array_equal(predicted_observables, shot_observables))}") + print(f"num_pq_pushed = {result.stats.num_pq_pushed}") + print(f"num_nodes_popped = {result.stats.num_nodes_popped}") + print(f"max_queue_size = {result.stats.max_queue_size}") + print(f"heuristic_evaluations = {result.stats.heuristic_evaluations}") + print(f"support_build_calls = {result.stats.support_build_calls}") + print(f"support_build_seconds = {result.stats.support_build_seconds:.6f}") + print(f"support_update_calls = {result.stats.support_update_calls}") + print(f"support_update_seconds = {result.stats.support_update_seconds:.6f}") + print(f"component_recompute_calls = {result.stats.component_recompute_calls}") + print(f"component_recompute_seconds = {result.stats.component_recompute_seconds:.6f}") + print(f"incremental_children = {result.stats.incremental_children}") + mean_changed_supports = ( + result.stats.changed_supports_total / result.stats.incremental_children + if result.stats.incremental_children else 0.0 + ) + mean_touched_detectors = ( + result.stats.touched_detectors_total / result.stats.incremental_children + if result.stats.incremental_children else 0.0 + ) + print(f"mean_changed_supports = {mean_changed_supports:.6f}") + print(f"mean_touched_detectors = {mean_touched_detectors:.6f}") + print(f"elapsed_seconds = {result.stats.elapsed_seconds:.6f}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_local_blast.py b/src/py/astar/astar_prototype_local_blast.py new file mode 100644 index 0000000..3524e83 --- /dev/null +++ b/src/py/astar/astar_prototype_local_blast.py @@ -0,0 +1,1203 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with projected singleton-LP refinement. + +The default heuristic matches the original prototype's plain detector-wise +heuristic. 
Passing --opt-singleton-detcost enables a lazy version of the exact +optimal singleton detector lower bound: + + * a node is first inserted with a cheap lower bound; + * when the node is popped, an LP-based refinement is optionally run; + * if the refined LP value raises the node's key, the node is reinserted; + * expanded nodes project their refined LP solution onto each child. + +By default the refinement is a full singleton LP solve on pop. Two experimental +modes are also available: + + * --local-lp-component: only reoptimize the active support component(s) + touched by the flipped error used to create the node. + * --local-lp-radius R: only reoptimize detector prices within an R-hop + neighborhood of the changed region, freezing all other prices at their + projected values. + +Both restricted modes remain admissible because they optimize over a subset of +variables while keeping the rest fixed at a feasible point. +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") +HEURISTIC_EPS = 1e-9 +FEAS_EPS = 1e-9 + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
@dataclass
class SupportSystem:
    # Active residual detectors, the minimum error weight per active support
    # (the subset of an error's detectors that are still active), and whether
    # every active detector is covered by at least one available error.
    active_dets: np.ndarray
    support_to_weight: Dict[Tuple[int, ...], float]
    covered_all: bool


@dataclass
class OptSingletonLPResult:
    # Outcome of one (possibly restricted) singleton-LP solve: the bound
    # value, the full price vector, problem-size counters, and which LP mode
    # ("full" / "component" / "radiusN") produced it.
    value: float
    y_full: np.ndarray
    num_active_dets: int
    num_supports: int
    num_free_vars: int
    seed_frontier_size: int
    mode: str


@dataclass
class SearchState:
    # One A* node: error/blocked masks, residual detector mask and per-error
    # active-detector counts, g/h costs, how h was obtained ("projected" vs
    # refined), whether the node may be expanded without further refinement,
    # and the optional projected LP solution plus its seed detectors.
    errs: np.ndarray
    blocked_errs: np.ndarray
    dets: np.ndarray
    det_counts: np.ndarray
    g_cost: float
    h_cost: float
    h_source: str
    ready_to_expand: bool
    lp_y: Optional[np.ndarray] = None
    seed_detectors: Optional[np.ndarray] = None


@dataclass
class DecodeResult:
    # Final decoding plus search/LP instrumentation counters.
    success: bool
    errs: np.ndarray
    residual_dets: np.ndarray
    cost: float
    nodes_pushed: int
    nodes_popped: int
    heuristic_calls: int
    plain_heuristic_calls: int
    projection_heuristic_calls: int
    refinement_calls: int
    lp_calls: int
    full_lp_calls: int
    component_lp_calls: int
    radius_lp_calls: int
    lp_reinserts: int
    projected_nodes_generated: int
    projected_nodes_refined: int
    projected_nodes_unrefined_at_finish: int
    total_lp_refinement_gain: float
    max_lp_refinement_gain: float
    elapsed_seconds: float


class AStarPrototypeDecoder:
    """A* decoder with a lazy, optionally localized singleton-LP heuristic."""

    def __init__(
        self,
        errors: Sequence[ErrorRecord],
        num_detectors: int,
        *,
        use_opt_singleton_detcost: bool = False,
        respect_blocked_errors_in_heuristic: bool = False,
        verbose_search: bool = False,
        local_lp_radius: Optional[int] = None,
        local_lp_component: bool = False,
    ) -> None:
        self.errors = list(errors)
        self.num_detectors = int(num_detectors)
        self.num_errors = len(self.errors)
        self.use_opt_singleton_detcost = use_opt_singleton_detcost
        self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic
        self.verbose_search = verbose_search
        self.local_lp_radius = local_lp_radius
        self.local_lp_component = local_lp_component

        if self.local_lp_radius is not None and self.local_lp_radius < 0:
            raise ValueError("local_lp_radius must be non-negative or None")
        # The two restricted-LP modes are mutually exclusive.
        if self.local_lp_radius is not None and self.local_lp_component:
            raise ValueError("Choose at most one of local_lp_radius and local_lp_component")

        # Per-error cost / detector / observable arrays.
        self.ecosts = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64)
        self.edets: List[np.ndarray] = [
            np.array(err.detectors, dtype=np.int32) for err in self.errors
        ]
        self.eobs: List[np.ndarray] = [
            np.array(err.observables, dtype=np.int32) for err in self.errors
        ]

        # Detector -> indices of incident errors.
        d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)]
        for ei, dets in enumerate(self.edets):
            for d in dets:
                d2e_lists[int(d)].append(ei)
        self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists]

        self.reset_stats()

    def reset_stats(self) -> None:
        """Zero all per-decode instrumentation counters."""
        self.heuristic_calls = 0
        self.plain_heuristic_calls = 0
        self.projection_heuristic_calls = 0
        self.refinement_calls = 0
        self.lp_calls = 0
        self.full_lp_calls = 0
        self.component_lp_calls = 0
        self.radius_lp_calls = 0
        self.lp_reinserts = 0
        self.projected_nodes_generated = 0
        self.projected_nodes_refined = 0
        self.total_lp_refinement_gain = 0.0
        self.max_lp_refinement_gain = 0.0

    @property
    def heuristic_name(self) -> str:
        """Human-readable identifier of the configured heuristic variant."""
        if not self.use_opt_singleton_detcost:
            return "plain-detcost"
        if self.local_lp_component:
            return "opt-singleton-detcost-lazy-projection-component"
        if self.local_lp_radius is not None:
            return f"opt-singleton-detcost-lazy-projection-radius{self.local_lp_radius}"
        return "opt-singleton-detcost-lazy-projection-full"

    def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray:
        """Mask of errors the heuristic may still use (optionally excluding blocked ones)."""
        available = ~errs
        if self.respect_blocked_errors_in_heuristic:
            available &= ~blocked_errs
        return available

    def _plain_detcost_heuristic(
        self,
        available_errs: np.ndarray,
        dets: np.ndarray,
        det_counts: np.ndarray,
    ) -> float:
        """Plain detector-wise lower bound: sum over active detectors of the
        cheapest incident available error's cost split over its active
        detector count. Returns INF if some detector has no available error.
        """
        self.heuristic_calls += 1
        self.plain_heuristic_calls += 1

        total = 0.0
        for d in np.flatnonzero(dets):
            best = INF
            for ei in self.d2e[int(d)]:
                ei = int(ei)
                if not available_errs[ei]:
                    continue
                count = int(det_counts[ei])
                assert count > 0
                # Split the error's cost evenly over its active detectors.
                value = self.ecosts[ei] / count
                if value < best:
                    best = value
            if math.isinf(best):
                # Detector cannot be explained by any available error.
                return INF
            total += best
        return total

    def _build_support_system(
        self,
        available_errs: np.ndarray,
        dets: np.ndarray,
        det_counts: np.ndarray,
    ) -> SupportSystem:
        """Collect the active supports (per-error active-detector subsets)
        with their minimum weights, and check detector coverage.
        """
        active_dets = np.flatnonzero(dets)
        if active_dets.size == 0:
            # Empty syndrome: trivially covered, no constraints.
            return SupportSystem(
                active_dets=active_dets,
                support_to_weight={},
                covered_all=True,
            )

        covered = np.zeros(self.num_detectors, dtype=bool)
        support_to_weight: Dict[Tuple[int, ...], float] = {}

        for ei in np.flatnonzero(available_errs):
            ei = int(ei)
            if int(det_counts[ei]) == 0:
                continue
            # Restrict the error's detectors to those still active.
            support = tuple(int(d) for d in self.edets[ei] if dets[int(d)])
            if not support:
                continue
            for d in support:
                covered[d] = True
            # Keep the cheapest weight per distinct support.
            weight = float(self.ecosts[ei])
            old = support_to_weight.get(support)
            if old is None or weight < old:
                support_to_weight[support] = weight

        covered_all = bool(np.all(covered[active_dets]))
        return SupportSystem(
            active_dets=active_dets,
            support_to_weight=support_to_weight,
            covered_all=covered_all,
        )

    def _build_active_neighbors(
        self,
        supports: Iterable[Tuple[int, ...]],
        active_dets: np.ndarray,
    ) -> Dict[int, Set[int]]:
        """Detector adjacency induced by multi-detector supports (two active
        detectors are neighbors when they share a support)."""
        neighbors: Dict[int, Set[int]] = {int(d): set() for d in active_dets.tolist()}
        for support in supports:
            if len(support) <= 1:
                continue
            support_list = list(support)
            for i, d in enumerate(support_list):
                nbrs = neighbors[int(d)]
                for od in support_list[:i]:
                    nbrs.add(int(od))
                for od in support_list[i + 1 :]:
                    nbrs.add(int(od))
        return neighbors

    def _seed_frontier_from_flipped_detectors(
        self,
        flipped_detectors: np.ndarray,
        dets: np.ndarray,
    ) -> np.ndarray:
        """Active detectors whose constraints may have changed: all active
        detectors of any error incident to a flipped detector."""
        if flipped_detectors is None or len(flipped_detectors) == 0:
            return np.zeros(0, dtype=np.int32)

        frontier: Set[int] = set()
        seen_errors: Set[int] = set()
        for fd in flipped_detectors:
            for ei in self.d2e[int(fd)]:
                ei = int(ei)
                if ei in seen_errors:
                    continue
                seen_errors.add(ei)
                for d in self.edets[ei]:
                    d = int(d)
                    if dets[d]:
                        frontier.add(d)
        if not frontier:
            return np.zeros(0, dtype=np.int32)
        return np.array(sorted(frontier), dtype=np.int32)

    def _free_dets_from_scope(
        self,
        support_system: SupportSystem,
        dets: np.ndarray,
        flipped_detectors: Optional[np.ndarray],
    ) -> Tuple[np.ndarray, int]:
        """Select which detector prices the restricted LP may reoptimize.

        component mode: BFS with no depth limit (the whole touched component);
        radius mode: BFS limited to `local_lp_radius` hops from the seed
        frontier; otherwise all active detectors are free (full LP).
        Returns (free detector ids, seed frontier size).
        """
        active_dets = support_system.active_dets
        if active_dets.size == 0:
            return np.zeros(0, dtype=np.int32), 0
        if flipped_detectors is None or len(flipped_detectors) == 0:
            return np.array(active_dets, copy=True), int(active_dets.size)

        seed_frontier = self._seed_frontier_from_flipped_detectors(flipped_detectors, dets)
        seed_frontier_size = int(seed_frontier.size)
        if seed_frontier_size == 0:
            return np.zeros(0, dtype=np.int32), 0

        neighbors = self._build_active_neighbors(support_system.support_to_weight.keys(), active_dets)
        if self.local_lp_component:
            radius_limit: Optional[int] = None
        elif self.local_lp_radius is not None:
            radius_limit = int(self.local_lp_radius)
        else:
            return np.array(active_dets, copy=True), seed_frontier_size

        # BFS over the active support graph from the seed frontier.
        visited: Set[int] = set(int(d) for d in seed_frontier.tolist())
        frontier: List[int] = [int(d) for d in seed_frontier.tolist()]
        depth = 0

        while frontier and (radius_limit is None or depth < radius_limit):
            next_frontier: List[int] = []
            for d in frontier:
                for od in neighbors.get(d, ()):  # detector adjacency in active support graph
                    if od in visited:
                        continue
                    visited.add(od)
                    next_frontier.append(od)
            frontier = next_frontier
            depth += 1

        if not visited:
            return np.zeros(0, dtype=np.int32), seed_frontier_size
        return np.array(sorted(visited), dtype=np.int32), seed_frontier_size

    def _solve_lp_with_fixed_outside(
        self,
        support_system: SupportSystem,
        base_y: np.ndarray,
free_dets: np.ndarray, + *, + mode: str, + seed_frontier_size: int, + ) -> OptSingletonLPResult: + self.heuristic_calls += 1 + self.refinement_calls += 1 + self.lp_calls += 1 + if mode == "full": + self.full_lp_calls += 1 + elif mode == "component": + self.component_lp_calls += 1 + elif mode.startswith("radius"): + self.radius_lp_calls += 1 + else: + raise ValueError(f"Unknown LP mode: {mode}") + + active_dets = support_system.active_dets + num_active_dets = int(active_dets.size) + num_supports = int(len(support_system.support_to_weight)) + + if num_active_dets == 0: + return OptSingletonLPResult( + value=0.0, + y_full=np.zeros(self.num_detectors, dtype=np.float64), + num_active_dets=0, + num_supports=0, + num_free_vars=0, + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + + if not support_system.covered_all: + return OptSingletonLPResult( + value=INF, + y_full=np.array(base_y, copy=True), + num_active_dets=num_active_dets, + num_supports=num_supports, + num_free_vars=int(free_dets.size), + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + + y_full = np.array(base_y, copy=True) + free_dets = np.array(sorted(set(int(d) for d in free_dets.tolist() if bool(base_y.shape[0] > d))), dtype=np.int32) + if free_dets.size == 0: + return OptSingletonLPResult( + value=float(y_full[active_dets].sum()), + y_full=y_full, + num_active_dets=num_active_dets, + num_supports=num_supports, + num_free_vars=0, + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + + free_set = set(int(d) for d in free_dets.tolist()) + det_to_var = {int(d): i for i, d in enumerate(free_dets.tolist())} + + row_indices: List[int] = [] + col_indices: List[int] = [] + data: List[float] = [] + rhs: List[float] = [] + row = 0 + + for support, weight in support_system.support_to_weight.items(): + fixed_sum = 0.0 + free_support_vars: List[int] = [] + for d in support: + d = int(d) + if d in free_set: + free_support_vars.append(det_to_var[d]) + else: + fixed_sum += float(y_full[d]) + 
remaining = float(weight) - fixed_sum + if remaining < -FEAS_EPS: + raise AssertionError( + f"Base y is infeasible in restricted LP: remaining={remaining} mode={mode}" + ) + remaining = max(0.0, remaining) + if not free_support_vars: + continue + rhs.append(remaining) + row_indices.extend([row] * len(free_support_vars)) + col_indices.extend(free_support_vars) + data.extend([1.0] * len(free_support_vars)) + row += 1 + + if row == 0: + return OptSingletonLPResult( + value=float(y_full[active_dets].sum()), + y_full=y_full, + num_active_dets=num_active_dets, + num_supports=num_supports, + num_free_vars=int(free_dets.size), + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + + a_ub = csr_matrix( + (data, (row_indices, col_indices)), + shape=(row, int(free_dets.size)), + dtype=np.float64, + ) + + result = linprog( + c=-np.ones(int(free_dets.size), dtype=np.float64), + A_ub=a_ub, + b_ub=np.array(rhs, dtype=np.float64), + bounds=[(0.0, None)] * int(free_dets.size), + method="highs", + ) + if result.status == 0: + y_full[free_dets] = np.asarray(result.x, dtype=np.float64) + return OptSingletonLPResult( + value=float(y_full[active_dets].sum()), + y_full=y_full, + num_active_dets=num_active_dets, + num_supports=num_supports, + num_free_vars=int(free_dets.size), + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + if result.status in {2, 3}: # infeasible or unbounded + return OptSingletonLPResult( + value=INF, + y_full=y_full, + num_active_dets=num_active_dets, + num_supports=num_supports, + num_free_vars=int(free_dets.size), + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + raise RuntimeError(f"linprog failed with status={result.status}: {result.message}") + + def _plain_heuristic_from_state(self, state: SearchState) -> float: + available = self._available_errors(state.errs, state.blocked_errs) + return self._plain_detcost_heuristic(available, state.dets, state.det_counts) + + def _project_child_solution_and_heuristic( + self, + parent_state: 
SearchState, + flipped_detectors: np.ndarray, + ) -> Tuple[np.ndarray, float]: + if parent_state.lp_y is None: + raise AssertionError("Expected parent LP solution before projecting to children.") + + self.heuristic_calls += 1 + self.projection_heuristic_calls += 1 + + child_y = np.array(parent_state.lp_y, copy=True) + for d in flipped_detectors: + d = int(d) + if parent_state.dets[d]: + child_y[d] = 0.0 + value = float(parent_state.h_cost) + for d in flipped_detectors: + d = int(d) + if parent_state.dets[d]: + value -= float(parent_state.lp_y[d]) + if value < -HEURISTIC_EPS: + raise AssertionError(f"Projected heuristic became negative: {value}") + return child_y, max(0.0, value) + + def _refine_scope_name(self) -> str: + if self.local_lp_component: + return "component" + if self.local_lp_radius is not None: + return f"radius{self.local_lp_radius}" + return "full" + + def _refine_node_lp( + self, + state: SearchState, + ) -> OptSingletonLPResult: + available = self._available_errors(state.errs, state.blocked_errs) + support_system = self._build_support_system(available, state.dets, state.det_counts) + + # Root or any node without a projected parent solution falls back to a full LP. 
+ if state.lp_y is None or state.seed_detectors is None: + base_y = np.zeros(self.num_detectors, dtype=np.float64) + return self._solve_lp_with_fixed_outside( + support_system, + base_y, + support_system.active_dets, + mode="full", + seed_frontier_size=0, + ) + + if self.local_lp_component or self.local_lp_radius is not None: + free_dets, seed_frontier_size = self._free_dets_from_scope( + support_system, + state.dets, + state.seed_detectors, + ) + mode = self._refine_scope_name() + return self._solve_lp_with_fixed_outside( + support_system, + state.lp_y, + free_dets, + mode=mode, + seed_frontier_size=seed_frontier_size, + ) + + return self._solve_lp_with_fixed_outside( + support_system, + np.zeros(self.num_detectors, dtype=np.float64), + support_system.active_dets, + mode="full", + seed_frontier_size=0, + ) + + def _maybe_refine_node_with_lp( + self, + node_id: int, + state: SearchState, + num_dets: int, + ) -> Tuple[SearchState, Optional[Tuple[float, int]], Optional[Dict[str, float | str]]]: + if not self.use_opt_singleton_detcost or state.ready_to_expand: + return state, None, None + + prev_h = state.h_cost + prev_source = state.h_source + lp_result = self._refine_node_lp(state) + refined_h = lp_result.value + + if math.isinf(refined_h): + refine_info = { + "approx_h": prev_h, + "exact_h": refined_h, + "delta": INF, + "num_vars": float(lp_result.num_active_dets), + "num_supports": float(lp_result.num_supports), + "num_free_vars": float(lp_result.num_free_vars), + "seed_frontier_size": float(lp_result.seed_frontier_size), + "reinserted": 0.0, + "discarded": 1.0, + "mode": lp_result.mode, + } + if prev_source == "projected": + self.projected_nodes_refined += 1 + return state, None, refine_info + + if refined_h + 1e-7 < prev_h: + raise AssertionError( + f"Refined LP lower bound {refined_h} is below stored {prev_source} lower bound {prev_h}." 
+ ) + + delta = refined_h - prev_h + if prev_source == "projected": + self.projected_nodes_refined += 1 + self.total_lp_refinement_gain += delta + self.max_lp_refinement_gain = max(self.max_lp_refinement_gain, delta) + + state.h_cost = refined_h + state.h_source = lp_result.mode + state.ready_to_expand = True + state.lp_y = lp_result.y_full + + should_reinsert = delta > HEURISTIC_EPS + reinsert_entry = (state.g_cost + refined_h, num_dets) if should_reinsert else None + if should_reinsert: + self.lp_reinserts += 1 + + refine_info = { + "approx_h": prev_h, + "exact_h": refined_h, + "delta": delta, + "num_vars": float(lp_result.num_active_dets), + "num_supports": float(lp_result.num_supports), + "num_free_vars": float(lp_result.num_free_vars), + "seed_frontier_size": float(lp_result.seed_frontier_size), + "reinserted": 1.0 if should_reinsert else 0.0, + "discarded": 0.0, + "mode": lp_result.mode, + } + return state, reinsert_entry, refine_info + + def _log_pop( + self, + *, + heap_len: int, + nodes_pushed: int, + nodes_popped: int, + num_dets: int, + max_num_dets: float, + f_cost: float, + state: SearchState, + ) -> None: + if not self.verbose_search: + return + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f"len(heap)={heap_len} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"lp_calls={self.lp_calls} full_lp_calls={self.full_lp_calls} " + f"component_lp_calls={self.component_lp_calls} radius_lp_calls={self.radius_lp_calls} " + f"lp_reinserts={self.lp_reinserts} proj_generated={self.projected_nodes_generated} " + f"proj_refined={self.projected_nodes_refined} proj_unrefined_so_far={projected_unrefined} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f} " + f"h={state.h_cost:.6f} h_source={state.h_source} ready_to_expand={state.ready_to_expand}" + ) + + def _log_refine(self, node_id: int, info: Dict[str, float | str]) -> None: + if not self.verbose_search: + return + 
exact_h = float(info["exact_h"]) + exact_text = "INF" if math.isinf(exact_h) else f"{exact_h:.6f}" + delta = float(info["delta"]) + delta_text = "INF" if math.isinf(delta) else f"{delta:.6f}" + print( + f" lp_refine node={node_id} mode={info['mode']} approx_h={float(info['approx_h']):.6f} " + f"refined_h={exact_text} delta={delta_text} vars={int(float(info['num_vars']))} " + f"supports={int(float(info['num_supports']))} free_vars={int(float(info['num_free_vars']))} " + f"seed_frontier={int(float(info['seed_frontier_size']))} " + f"reinserted={bool(info['reinserted'])} discarded={bool(info['discarded'])}" + ) + + def _log_expand( + self, + *, + node_id: int, + children_generated: int, + children_projected: int, + children_beam_pruned: int, + children_infeasible: int, + ) -> None: + if not self.verbose_search: + return + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f" expanded node={node_id} children_generated={children_generated} " + f"children_projected={children_projected} beam_pruned={children_beam_pruned} " + f"infeasible={children_infeasible} lp_calls={self.lp_calls} full_lp_calls={self.full_lp_calls} " + f"component_lp_calls={self.component_lp_calls} radius_lp_calls={self.radius_lp_calls} " + f"proj_unrefined_so_far={projected_unrefined}" + ) + + def _result( + self, + *, + success: bool, + errs: np.ndarray, + residual_dets: np.ndarray, + cost: float, + nodes_pushed: int, + nodes_popped: int, + start_time: float, + ) -> DecodeResult: + return DecodeResult( + success=success, + errs=errs, + residual_dets=residual_dets, + cost=cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + full_lp_calls=self.full_lp_calls, + component_lp_calls=self.component_lp_calls, + 
radius_lp_calls=self.radius_lp_calls, + lp_reinserts=self.lp_reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=( + self.projected_nodes_generated - self.projected_nodes_refined + ), + total_lp_refinement_gain=self.total_lp_refinement_gain, + max_lp_refinement_gain=self.max_lp_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + root_state = SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + h_cost=0.0, + h_source="plain", + ready_to_expand=not self.use_opt_singleton_detcost, + lp_y=None, + seed_detectors=None, + ) + root_state.h_cost = self._plain_heuristic_from_state(root_state) + if math.isinf(root_state.h_cost): + return self._result( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + start_time=start_time, + ) + + next_node_id = 1 + heap: List[Tuple[float, int, int]] = [ + (root_state.g_cost + root_state.h_cost, int(dets0.sum()), 0) + ] + node_data: Dict[int, SearchState] = {0: root_state} + + nodes_pushed = 1 + nodes_popped = 0 + min_num_dets = int(dets0.sum()) + + while heap: + f_cost, num_dets, node_id = heapq.heappop(heap) + state = node_data.pop(node_id, None) + if state is None: + continue + nodes_popped += 1 + + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = 
min_num_dets + det_beam + + self._log_pop( + heap_len=len(heap), + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + num_dets=num_dets, + max_num_dets=max_num_dets, + f_cost=f_cost, + state=state, + ) + + if num_dets == 0: + return self._result( + success=True, + errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + start_time=start_time, + ) + + state, reinsert_entry, refine_info = self._maybe_refine_node_with_lp( + node_id=node_id, + state=state, + num_dets=num_dets, + ) + if refine_info is not None: + self._log_refine(node_id, refine_info) + if bool(refine_info["discarded"]): + continue + if reinsert_entry is not None: + node_data[node_id] = state + heapq.heappush(heap, (reinsert_entry[0], reinsert_entry[1], node_id)) + continue + + if self.use_opt_singleton_detcost and not state.ready_to_expand: + raise AssertionError("Opt-singleton mode should only expand refined nodes.") + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked_errs = state.blocked_errs.copy() + + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + children_projected = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked_errs[ei] = True + + if state.errs[ei] or state.blocked_errs[ei]: + continue + + child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked_errs = prefix_blocked_errs.copy() + child_dets = state.dets.copy() + child_det_counts = state.det_counts.copy() + + for d in self.edets[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_g = state.g_cost + float(self.ecosts[ei]) + + if self.use_opt_singleton_detcost: + if state.lp_y is None: + raise 
AssertionError("Expected a refined parent LP solution before projection.") + child_lp_y, child_h = self._project_child_solution_and_heuristic( + state, + self.edets[ei], + ) + child_h_source = "projected" + child_ready_to_expand = False + child_seed_detectors = np.array(self.edets[ei], copy=True) + self.projected_nodes_generated += 1 + children_projected += 1 + else: + child_tmp_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked_errs, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=0.0, + h_source="plain", + ready_to_expand=True, + lp_y=None, + seed_detectors=None, + ) + child_h = self._plain_heuristic_from_state(child_tmp_state) + child_h_source = "plain" + child_ready_to_expand = True + child_lp_y = None + child_seed_detectors = None + if math.isinf(child_h): + children_infeasible += 1 + continue + + child_id = next_node_id + next_node_id += 1 + node_data[child_id] = SearchState( + errs=child_errs, + blocked_errs=child_blocked_errs, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=child_h, + h_source=child_h_source, + ready_to_expand=child_ready_to_expand, + lp_y=child_lp_y, + seed_detectors=child_seed_detectors, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, child_id)) + nodes_pushed += 1 + children_generated += 1 + + self._log_expand( + node_id=node_id, + children_generated=children_generated, + children_projected=children_projected, + children_beam_pruned=children_beam_pruned, + children_infeasible=children_infeasible, + ) + + return self._result( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + start_time=start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.ecosts[errs].sum()) + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in 
np.flatnonzero(errs): + for obs in self.eobs[int(ei)]: + obs = int(obs) + parity[obs] = not parity.get(obs, False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.edets[int(ei)]: + dets[int(d)] ^= True + return dets + + +def merged_errors_from_dem(dem) -> List[ErrorRecord]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + + for error in dem.flattened(): + if error.type != "error": + continue + + probability = float(error.args_copy()[0]) + if probability <= 0: + continue + if probability > 0.5: + raise ValueError( + f"Expected flattened error probabilities in (0, 0.5], got {probability}." + ) + + detectors: set[int] = set() + observables: set[int] = set() + for target in error.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected target type: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + + key = (tuple(sorted(detectors)), tuple(sorted(observables))) + p_old = errors_by_symptom.get(key) + if p_old is None: + p_new = probability + else: + p_new = p_old * (1.0 - probability) + (1.0 - p_old) * probability + errors_by_symptom[key] = p_new + + merged: List[ErrorRecord] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + merged.append( + ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +def sample_detections_and_observables(circuit, num_shots: int, seed: int) -> Tuple[np.ndarray, np.ndarray]: + 
sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=circuit.num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=circuit.num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def parse_nonnegative_int(text: str) -> int: + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("Expected a non-negative integer.") + return value + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder using the plain detector-wise heuristic or a lazy " + "projected version of the optimal singleton detector heuristic." 
+ ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--shot", + type=int, + default=0, + help="Shot index to decode after sampling --sample-num-shots shots (default: 0).", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample before selecting --shot (default: 100).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the number of residual detections; use 'inf' for none (default: inf).", + ) + parser.add_argument( + "--opt-singleton-detcost", + action="store_true", + help=( + "Use lazy refinement of the optimal singleton detector-cost lower bound. " + "Nodes are seeded with projected LP prices from their parent and refined " + "when popped." + ), + ) + parser.add_argument( + "--local-lp-component", + action="store_true", + help=( + "Instead of a full LP on pop, only reoptimize the active support component(s) " + "touched by the flipped error that created the node." + ), + ) + parser.add_argument( + "--local-lp-radius", + type=parse_nonnegative_int, + default=None, + help=( + "Instead of a full LP on pop, only reoptimize detector prices within this " + "radius in the active support graph around the changed region." + ), + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action="store_true", + help=( + "Exclude precedence-blocked errors from the heuristic. By default the script " + "preserves the original prototype's behavior and only excludes already-activated errors." 
+ ), + ) + parser.add_argument( + "--show-detections", + action="store_true", + help="Print the selected shot's detection events before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action="store_true", + help="Print the decoded merged-error indices.", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print detailed search, LP-refinement, projection, and locality statistics during A* search.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + if args.local_lp_component and args.local_lp_radius is not None: + parser.error("Choose at most one of --local-lp-component and --local-lp-radius.") + if (args.local_lp_component or args.local_lp_radius is not None) and not args.opt_singleton_detcost: + parser.error("Local LP refinement flags require --opt-singleton-detcost.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) + + dets_unpacked, obs_unpacked = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + ) + shot_dets = dets_unpacked[args.shot] + shot_obs = obs_unpacked[args.shot] + + if args.show_detections: + active_dets = np.flatnonzero(shot_dets) + print("detections:", " ".join(f"D{d}" for d in active_dets)) + + decoder = AStarPrototypeDecoder( + errors, + dem.num_detectors, + use_opt_singleton_detcost=args.opt_singleton_detcost, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + verbose_search=args.verbose_search, + local_lp_radius=args.local_lp_radius, + 
local_lp_component=args.local_lp_component, + ) + result = decoder.decode(shot_dets, det_beam=args.det_beam) + + print(f"heuristic: {decoder.heuristic_name}") + print(f"shot: {args.shot} / {args.sample_num_shots}") + print(f"success: {result.success}") + print(f"nodes_pushed: {result.nodes_pushed}") + print(f"nodes_popped: {result.nodes_popped}") + print(f"heuristic_calls: {result.heuristic_calls}") + print(f"plain_heuristic_calls: {result.plain_heuristic_calls}") + print(f"projection_heuristic_calls: {result.projection_heuristic_calls}") + print(f"refinement_calls: {result.refinement_calls}") + print(f"lp_calls: {result.lp_calls}") + print(f"full_lp_calls: {result.full_lp_calls}") + print(f"component_lp_calls: {result.component_lp_calls}") + print(f"radius_lp_calls: {result.radius_lp_calls}") + print(f"lp_reinserts: {result.lp_reinserts}") + print(f"projected_nodes_generated: {result.projected_nodes_generated}") + print(f"projected_nodes_refined: {result.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish: {result.projected_nodes_unrefined_at_finish}") + print(f"total_lp_refinement_gain: {result.total_lp_refinement_gain:.6f}") + print(f"max_lp_refinement_gain: {result.max_lp_refinement_gain:.6f}") + print(f"elapsed_seconds: {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + decoded_err_indices = np.flatnonzero(result.errs) + if args.show_error_indices: + print("decoded_error_indices:", " ".join(map(str, decoded_err_indices.tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + reproduced_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + actual_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors: {int(result.errs.sum())}") + print(f"decoded_cost: {reproduced_cost:.12f}") 
+ print("predicted_observables:", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables:", " ".join(f"L{o}" for o in actual_obs.tolist())) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_projected.py b/src/py/astar/astar_prototype_projected.py new file mode 100644 index 0000000..5843c3a --- /dev/null +++ b/src/py/astar/astar_prototype_projected.py @@ -0,0 +1,882 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with lazy singleton-LP refinement. + +The default heuristic matches the original prototype's plain detector-wise +heuristic. Passing --opt-singleton-detcost enables a lazy version of the exact +optimal singleton detector lower bound: + + * a node is first inserted with a cheap lower bound; + * when the node is popped, the exact singleton LP is solved; + * if the exact LP value raises the node's key, the node is reinserted; + * expanded nodes project their exact LP solution onto each child to seed a + much tighter cheap first-pass lower bound than plain detcost. + +This keeps the prototype pedagogical while making the expensive LP solves much +more selective. +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
+ + +@dataclass +class OptSingletonLPResult: + value: float + y_full: np.ndarray + num_active_dets: int + num_supports: int + + +@dataclass +class SearchState: + errs: np.ndarray + blocked_errs: np.ndarray + dets: np.ndarray + det_counts: np.ndarray + g_cost: float + h_cost: float + h_source: str + exact_refined: bool + lp_y: Optional[np.ndarray] = None + + +@dataclass +class DecodeResult: + success: bool + errs: np.ndarray + residual_dets: np.ndarray + cost: float + nodes_pushed: int + nodes_popped: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + exact_refinement_calls: int + lp_calls: int + lp_reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_lp_refinement_gain: float + max_lp_refinement_gain: float + elapsed_seconds: float + + +class AStarPrototypeDecoder: + def __init__( + self, + errors: Sequence[ErrorRecord], + num_detectors: int, + *, + use_opt_singleton_detcost: bool = False, + respect_blocked_errors_in_heuristic: bool = False, + verbose_search: bool = False, + ) -> None: + self.errors = list(errors) + self.num_detectors = int(num_detectors) + self.num_errors = len(self.errors) + self.use_opt_singleton_detcost = use_opt_singleton_detcost + self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic + self.verbose_search = verbose_search + + self.ecosts = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64) + self.edets: List[np.ndarray] = [ + np.array(err.detectors, dtype=np.int32) for err in self.errors + ] + self.eobs: List[np.ndarray] = [ + np.array(err.observables, dtype=np.int32) for err in self.errors + ] + + d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)] + for ei, dets in enumerate(self.edets): + for d in dets: + d2e_lists[int(d)].append(ei) + self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists] + + self.reset_stats() + + def reset_stats(self) 
-> None: + self.heuristic_calls = 0 + self.plain_heuristic_calls = 0 + self.projection_heuristic_calls = 0 + self.exact_refinement_calls = 0 + self.lp_calls = 0 + self.lp_reinserts = 0 + self.projected_nodes_generated = 0 + self.projected_nodes_refined = 0 + self.total_lp_refinement_gain = 0.0 + self.max_lp_refinement_gain = 0.0 + + @property + def heuristic_name(self) -> str: + if self.use_opt_singleton_detcost: + return "opt-singleton-detcost-lazy-projection" + return "plain-detcost" + + def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + return available + + def _plain_detcost_heuristic( + self, + available_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> float: + self.heuristic_calls += 1 + self.plain_heuristic_calls += 1 + + total = 0.0 + for d in np.flatnonzero(dets): + best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + assert count > 0 + value = self.ecosts[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF + total += best + return total + + def _solve_opt_singleton_lp( + self, + available_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> OptSingletonLPResult: + self.heuristic_calls += 1 + self.exact_refinement_calls += 1 + + active_dets = np.flatnonzero(dets) + if active_dets.size == 0: + return OptSingletonLPResult( + value=0.0, + y_full=np.zeros(self.num_detectors, dtype=np.float64), + num_active_dets=0, + num_supports=0, + ) + + det_to_var = {int(d): i for i, d in enumerate(active_dets.tolist())} + support_to_weight: Dict[Tuple[int, ...], float] = {} + covered = np.zeros(active_dets.size, dtype=bool) + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + if int(det_counts[ei]) == 0: + continue + support = tuple(det_to_var[int(d)] for d in self.edets[ei] if 
dets[int(d)]) + if not support: + continue + for var in support: + covered[var] = True + weight = float(self.ecosts[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + if not np.all(covered): + return OptSingletonLPResult( + value=INF, + y_full=np.zeros(self.num_detectors, dtype=np.float64), + num_active_dets=int(active_dets.size), + num_supports=len(support_to_weight), + ) + + supports = list(support_to_weight.keys()) + weights = np.array([support_to_weight[s] for s in supports], dtype=np.float64) + num_vars = int(active_dets.size) + + row_indices: List[int] = [] + col_indices: List[int] = [] + data: List[float] = [] + for row, support in enumerate(supports): + row_indices.extend([row] * len(support)) + col_indices.extend(support) + data.extend([1.0] * len(support)) + + a_ub = csr_matrix( + (data, (row_indices, col_indices)), + shape=(len(supports), num_vars), + dtype=np.float64, + ) + + self.lp_calls += 1 + result = linprog( + c=-np.ones(num_vars, dtype=np.float64), + A_ub=a_ub, + b_ub=weights, + bounds=[(0.0, None)] * num_vars, + method="highs", + ) + if result.status == 0: + y_full = np.zeros(self.num_detectors, dtype=np.float64) + y_full[active_dets] = np.asarray(result.x, dtype=np.float64) + return OptSingletonLPResult( + value=max(0.0, float(-result.fun)), + y_full=y_full, + num_active_dets=num_vars, + num_supports=len(supports), + ) + if result.status in {2, 3}: # infeasible or unbounded + return OptSingletonLPResult( + value=INF, + y_full=np.zeros(self.num_detectors, dtype=np.float64), + num_active_dets=num_vars, + num_supports=len(supports), + ) + raise RuntimeError(f"linprog failed with status={result.status}: {result.message}") + + def _plain_heuristic_from_state(self, state: SearchState) -> float: + available = self._available_errors(state.errs, state.blocked_errs) + return self._plain_detcost_heuristic(available, state.dets, state.det_counts) + + def _project_child_heuristic( + self, 
+ parent_state: SearchState, + flipped_detectors: np.ndarray, + ) -> float: + if parent_state.lp_y is None: + raise AssertionError("Expected parent exact LP solution before projecting to children.") + + self.heuristic_calls += 1 + self.projection_heuristic_calls += 1 + + value = parent_state.h_cost + for d in flipped_detectors: + d = int(d) + if parent_state.dets[d]: + value -= float(parent_state.lp_y[d]) + if value < -HEURISTIC_EPS: + raise AssertionError(f"Projected heuristic became negative: {value}") + return max(0.0, value) + + def _maybe_refine_node_with_exact_lp( + self, + node_id: int, + state: SearchState, + num_dets: int, + ) -> Tuple[SearchState, Optional[Tuple[float, int]], Optional[Dict[str, float]]]: + if not self.use_opt_singleton_detcost or state.exact_refined: + return state, None, None + + prev_h = state.h_cost + prev_source = state.h_source + available = self._available_errors(state.errs, state.blocked_errs) + lp_result = self._solve_opt_singleton_lp(available, state.dets, state.det_counts) + exact_h = lp_result.value + + if math.isinf(exact_h): + refine_info = { + "approx_h": prev_h, + "exact_h": exact_h, + "delta": INF, + "num_vars": float(lp_result.num_active_dets), + "num_supports": float(lp_result.num_supports), + "reinserted": 0.0, + "discarded": 1.0, + } + if prev_source == "projected": + self.projected_nodes_refined += 1 + return state, None, refine_info + + if exact_h + 1e-7 < prev_h: + raise AssertionError( + f"Exact LP lower bound {exact_h} is below stored {prev_source} lower bound {prev_h}." 
+ ) + + delta = exact_h - prev_h + if prev_source == "projected": + self.projected_nodes_refined += 1 + self.total_lp_refinement_gain += delta + self.max_lp_refinement_gain = max(self.max_lp_refinement_gain, delta) + + state.h_cost = exact_h + state.h_source = "exact" + state.exact_refined = True + state.lp_y = lp_result.y_full + + should_reinsert = delta > HEURISTIC_EPS + reinsert_entry = (state.g_cost + exact_h, num_dets) if should_reinsert else None + if should_reinsert: + self.lp_reinserts += 1 + + refine_info = { + "approx_h": prev_h, + "exact_h": exact_h, + "delta": delta, + "num_vars": float(lp_result.num_active_dets), + "num_supports": float(lp_result.num_supports), + "reinserted": 1.0 if should_reinsert else 0.0, + "discarded": 0.0, + } + return state, reinsert_entry, refine_info + + def _log_pop( + self, + *, + heap_len: int, + nodes_pushed: int, + nodes_popped: int, + num_dets: int, + max_num_dets: float, + f_cost: float, + state: SearchState, + ) -> None: + if not self.verbose_search: + return + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f"len(heap)={heap_len} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"lp_calls={self.lp_calls} lp_reinserts={self.lp_reinserts} " + f"proj_generated={self.projected_nodes_generated} proj_refined={self.projected_nodes_refined} " + f"proj_unrefined_so_far={projected_unrefined} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f} " + f"h={state.h_cost:.6f} h_source={state.h_source} exact_refined={state.exact_refined}" + ) + + def _log_refine(self, node_id: int, info: Dict[str, float]) -> None: + if not self.verbose_search: + return + exact_h = info["exact_h"] + exact_text = "INF" if math.isinf(exact_h) else f"{exact_h:.6f}" + delta = info["delta"] + delta_text = "INF" if math.isinf(delta) else f"{delta:.6f}" + print( + f" lp_refine node={node_id} approx_h={info['approx_h']:.6f} exact_h={exact_text} " + 
f"delta={delta_text} vars={int(info['num_vars'])} supports={int(info['num_supports'])} " + f"reinserted={bool(info['reinserted'])} discarded={bool(info['discarded'])}" + ) + + def _log_expand( + self, + *, + node_id: int, + children_generated: int, + children_projected: int, + children_beam_pruned: int, + children_infeasible: int, + ) -> None: + if not self.verbose_search: + return + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f" expanded node={node_id} children_generated={children_generated} " + f"children_projected={children_projected} beam_pruned={children_beam_pruned} " + f"infeasible={children_infeasible} lp_calls={self.lp_calls} " + f"proj_unrefined_so_far={projected_unrefined}" + ) + + def _result( + self, + *, + success: bool, + errs: np.ndarray, + residual_dets: np.ndarray, + cost: float, + nodes_pushed: int, + nodes_popped: int, + start_time: float, + ) -> DecodeResult: + return DecodeResult( + success=success, + errs=errs, + residual_dets=residual_dets, + cost=cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + exact_refinement_calls=self.exact_refinement_calls, + lp_calls=self.lp_calls, + lp_reinserts=self.lp_reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=( + self.projected_nodes_generated - self.projected_nodes_refined + ), + total_lp_refinement_gain=self.total_lp_refinement_gain, + max_lp_refinement_gain=self.max_lp_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, 
dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + root_state = SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + h_cost=0.0, + h_source="plain", + exact_refined=not self.use_opt_singleton_detcost, + lp_y=None, + ) + root_state.h_cost = self._plain_heuristic_from_state(root_state) + if math.isinf(root_state.h_cost): + return self._result( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + start_time=start_time, + ) + + next_node_id = 1 + heap: List[Tuple[float, int, int]] = [ + (root_state.g_cost + root_state.h_cost, int(dets0.sum()), 0) + ] + node_data: Dict[int, SearchState] = {0: root_state} + + nodes_pushed = 1 + nodes_popped = 0 + min_num_dets = int(dets0.sum()) + + while heap: + f_cost, num_dets, node_id = heapq.heappop(heap) + state = node_data.pop(node_id, None) + if state is None: + continue + nodes_popped += 1 + + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + self._log_pop( + heap_len=len(heap), + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + num_dets=num_dets, + max_num_dets=max_num_dets, + f_cost=f_cost, + state=state, + ) + + if num_dets == 0: + return self._result( + success=True, + errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + start_time=start_time, + ) + + state, reinsert_entry, refine_info = self._maybe_refine_node_with_exact_lp( + node_id=node_id, + state=state, + num_dets=num_dets, + ) + if refine_info is not None: + self._log_refine(node_id, refine_info) + if bool(refine_info["discarded"]): + continue + if reinsert_entry is not None: + 
node_data[node_id] = state + heapq.heappush(heap, (reinsert_entry[0], reinsert_entry[1], node_id)) + continue + + if self.use_opt_singleton_detcost and not state.exact_refined: + raise AssertionError("Opt-singleton mode should only expand exact-refined nodes.") + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked_errs = state.blocked_errs.copy() + + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + children_projected = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked_errs[ei] = True + + if state.errs[ei] or state.blocked_errs[ei]: + continue + + child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked_errs = prefix_blocked_errs.copy() + child_dets = state.dets.copy() + child_det_counts = state.det_counts.copy() + + for d in self.edets[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_g = state.g_cost + float(self.ecosts[ei]) + + if self.use_opt_singleton_detcost: + child_h = self._project_child_heuristic(state, self.edets[ei]) + child_h_source = "projected" + child_exact_refined = False + child_lp_y = None + self.projected_nodes_generated += 1 + children_projected += 1 + else: + child_tmp_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked_errs, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=0.0, + h_source="plain", + exact_refined=False, + lp_y=None, + ) + child_h = self._plain_heuristic_from_state(child_tmp_state) + child_h_source = "plain" + child_exact_refined = True + child_lp_y = None + if math.isinf(child_h): + children_infeasible += 1 + continue + + child_id = next_node_id + next_node_id += 1 + node_data[child_id] = SearchState( + errs=child_errs, + 
blocked_errs=child_blocked_errs, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=child_h, + h_source=child_h_source, + exact_refined=child_exact_refined, + lp_y=child_lp_y, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, child_id)) + nodes_pushed += 1 + children_generated += 1 + + self._log_expand( + node_id=node_id, + children_generated=children_generated, + children_projected=children_projected, + children_beam_pruned=children_beam_pruned, + children_infeasible=children_infeasible, + ) + + return self._result( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + start_time=start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.ecosts[errs].sum()) + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.eobs[int(ei)]: + obs = int(obs) + parity[obs] = not parity.get(obs, False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.edets[int(ei)]: + dets[int(d)] ^= True + return dets + + +def merged_errors_from_dem(dem) -> List[ErrorRecord]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + + for error in dem.flattened(): + if error.type != "error": + continue + + probability = float(error.args_copy()[0]) + if probability <= 0: + continue + if probability > 0.5: + raise ValueError( + f"Expected flattened error probabilities in (0, 0.5], got {probability}." 
+ ) + + detectors: set[int] = set() + observables: set[int] = set() + for target in error.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected target type: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + + key = (tuple(sorted(detectors)), tuple(sorted(observables))) + p_old = errors_by_symptom.get(key) + if p_old is None: + p_new = probability + else: + p_new = p_old * (1.0 - probability) + (1.0 - p_old) * probability + errors_by_symptom[key] = p_new + + merged: List[ErrorRecord] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + merged.append( + ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +def sample_detections_and_observables(circuit, num_shots: int, seed: int) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=circuit.num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=circuit.num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + 
description=( + "Prototype A* decoder using the plain detector-wise heuristic or a lazy " + "projected version of the optimal singleton detector heuristic." + ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--shot", + type=int, + default=0, + help="Shot index to decode after sampling --sample-num-shots shots (default: 0).", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample before selecting --shot (default: 100).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the number of residual detections; use 'inf' for none (default: inf).", + ) + parser.add_argument( + "--opt-singleton-detcost", + action="store_true", + help=( + "Use lazy refinement of the exact optimal singleton detector-cost lower bound. " + "Nodes are seeded with projected LP prices from their parent and only solved " + "exactly when popped." + ), + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action="store_true", + help=( + "Exclude precedence-blocked errors from the heuristic. By default the script " + "preserves the original prototype's behavior and only excludes already-activated errors." 
+ ), + ) + parser.add_argument( + "--show-detections", + action="store_true", + help="Print the selected shot's detection events before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action="store_true", + help="Print the decoded merged-error indices.", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print detailed search, LP-refinement, and projection statistics during A* search.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) + + dets_unpacked, obs_unpacked = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + ) + shot_dets = dets_unpacked[args.shot] + shot_obs = obs_unpacked[args.shot] + + if args.show_detections: + active_dets = np.flatnonzero(shot_dets) + print("detections:", " ".join(f"D{d}" for d in active_dets)) + + decoder = AStarPrototypeDecoder( + errors, + dem.num_detectors, + use_opt_singleton_detcost=args.opt_singleton_detcost, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + verbose_search=args.verbose_search, + ) + result = decoder.decode(shot_dets, det_beam=args.det_beam) + + print(f"heuristic: {decoder.heuristic_name}") + print(f"shot: {args.shot} / {args.sample_num_shots}") + print(f"success: {result.success}") + print(f"nodes_pushed: {result.nodes_pushed}") + print(f"nodes_popped: {result.nodes_popped}") + print(f"heuristic_calls: {result.heuristic_calls}") + print(f"plain_heuristic_calls: 
{result.plain_heuristic_calls}") + print(f"projection_heuristic_calls: {result.projection_heuristic_calls}") + print(f"exact_refinement_calls: {result.exact_refinement_calls}") + print(f"lp_calls: {result.lp_calls}") + print(f"lp_reinserts: {result.lp_reinserts}") + print(f"projected_nodes_generated: {result.projected_nodes_generated}") + print(f"projected_nodes_refined: {result.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish: {result.projected_nodes_unrefined_at_finish}") + print(f"total_lp_refinement_gain: {result.total_lp_refinement_gain:.6f}") + print(f"max_lp_refinement_gain: {result.max_lp_refinement_gain:.6f}") + print(f"elapsed_seconds: {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + decoded_err_indices = np.flatnonzero(result.errs) + if args.show_error_indices: + print("decoded_error_indices:", " ".join(map(str, decoded_err_indices.tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + reproduced_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + actual_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors: {int(result.errs.sum())}") + print(f"decoded_cost: {reproduced_cost:.12f}") + print("predicted_observables:", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables:", " ".join(f"L{o}" for o in actual_obs.tolist())) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_singleton_greedy_heuristics.py b/src/py/astar/astar_prototype_singleton_greedy_heuristics.py new file mode 100644 index 0000000..42e65e5 --- /dev/null +++ b/src/py/astar/astar_prototype_singleton_greedy_heuristics.py @@ -0,0 +1,751 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder for experimenting with 
fast singleton-budget heuristics. + +This version mirrors the earlier Stim-based prototypes: + * load a .stim circuit, + * extract its detector error model with decompose_errors=False, + * optionally merge indistinguishable errors, + * sample detector shots from Stim, + * run precedence-pruned A* with a selectable singleton lower-bound heuristic. + +Supported heuristic choices: + plain original detector-wise feasible point + asc_deg zero-start saturation ordered by ascending detector degree + desc_plain zero-start saturation ordered by descending plain y_d + plain_sweep start from plain, then one descending saturation sweep + best_of_two max(plain_sweep, asc_deg) + best_of_three max(plain_sweep, asc_deg, desc_plain) + exact_lp exact optimal singleton LP lower bound + +The greedy heuristics are derived from feasible points of the singleton LP + + max sum_d y_d + s.t. sum_{d in T} y_d <= W(T) + y_d >= 0, + +where W(T) is the cheapest available error whose active support is T. +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
@dataclass
class SupportData:
    """Compressed view of the residual decoding problem used by the heuristics.

    Only "active" (still unexplained) detectors and still-available errors
    contribute; distinct errors sharing the same active support are collapsed
    to the cheapest representative.
    """

    # Sorted indices of currently active detectors.
    active_detectors: List[int]
    # One (active-support tuple, cheapest weight) pair per distinct nonempty support.
    supports: List[Tuple[Tuple[int, ...], float]]
    # For each active detector, indices into `supports` that contain it.
    incident: Dict[int, List[int]]


@dataclass
class SearchState:
    """A node of the precedence-pruned A* search."""

    errs: np.ndarray          # Errors chosen on the path to this node.
    blocked_errs: np.ndarray  # Errors forbidden by the precedence pruning rule.
    dets: np.ndarray          # Detectors still unexplained at this node.
    g_cost: float             # Accumulated likelihood cost of the chosen errors.


@dataclass
class DecodeResult:
    """Outcome and bookkeeping statistics of a single decode() call."""

    success: bool
    errs: np.ndarray
    residual_dets: np.ndarray
    cost: float
    nodes_pushed: int
    nodes_popped: int
    heuristic_calls: int
    elapsed_seconds: float


class UnionFind:
    """Disjoint-set forest with path halving and union by rank."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        # Path halving: point every other node at its grandparent while walking up.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> None:
        ra = self.find(a)
        rb = self.find(b)
        if ra == rb:
            return
        if self.rank[ra] < self.rank[rb]:
            self.parent[ra] = rb
        elif self.rank[ra] > self.rank[rb]:
            self.parent[rb] = ra
        else:
            self.parent[rb] = ra
            self.rank[ra] += 1


def xor_probability(p0: float, p1: float) -> float:
    """Probability that exactly one of two independent events occurs."""
    return p0 * (1.0 - p1) + (1.0 - p0) * p1


def iter_dem_errors_from_dem(dem: stim.DetectorErrorModel) -> Iterable[ErrorRecord]:
    """Yield one ErrorRecord per positive-probability error in the flattened DEM.

    Repeated detector/observable targets within one error cancel pairwise
    (XOR semantics). Raises ValueError for probabilities >= 0.5, which cannot
    be assigned a positive likelihood cost.
    """
    for instruction in dem.flattened():
        if instruction.type != "error":
            continue
        probability = float(instruction.args_copy()[0])
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                f"Expected flattened error probabilities in (0, 0.5), got {probability}."
            )

        detectors: set[int] = set()
        observables: set[int] = set()
        for target in instruction.targets_copy():
            if target.is_separator():
                continue
            if target.is_logical_observable_id():
                # Toggle membership: a repeated target cancels out.
                observables ^= {target.val}
            else:
                if not target.is_relative_detector_id():
                    raise ValueError(f"Unexpected DEM target: {target!r}")
                detectors ^= {target.val}

        yield ErrorRecord(
            probability=probability,
            likelihood_cost=-math.log(probability / (1.0 - probability)),
            detectors=tuple(sorted(detectors)),
            observables=tuple(sorted(observables)),
        )


def merged_errors_from_dem(dem: stim.DetectorErrorModel) -> List[ErrorRecord]:
    """Merge indistinguishable DEM errors (identical detectors and observables).

    Probabilities of merged errors combine under XOR; a merged probability
    >= 0.5 is rejected because its likelihood cost would be non-positive.
    """
    errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {}
    for error in iter_dem_errors_from_dem(dem):
        key = (error.detectors, error.observables)
        p_old = errors_by_symptom.get(key)
        errors_by_symptom[key] = (
            error.probability if p_old is None else xor_probability(p_old, error.probability)
        )

    merged: List[ErrorRecord] = []
    for (detectors, observables), probability in errors_by_symptom.items():
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                f"Merged error has probability >= 0.5 ({probability}); cannot assign positive cost."
            )
        merged.append(
            ErrorRecord(
                probability=probability,
                likelihood_cost=-math.log(probability / (1.0 - probability)),
                detectors=detectors,
                observables=observables,
            )
        )
    return merged


class GreedySingletonHeuristicDecoder:
    """Precedence-pruned A* decoder with selectable singleton-budget heuristics.

    Every heuristic lower-bounds the cost of explaining the remaining active
    detectors via a feasible point of the singleton LP

        max sum_d y_d   s.t.   sum_{d in T} y_d <= W(T),   y_d >= 0,

    where W(T) is the cheapest available error whose active support is T.
    """

    def __init__(
        self,
        errors: Sequence[ErrorRecord],
        num_detectors: int,
        num_observables: int,
        *,
        heuristic: str = "best_of_two",
        respect_blocked_errors_in_heuristic: bool = True,
        verbose_search: bool = False,
    ) -> None:
        self.errors = list(errors)
        self.num_errors = len(self.errors)
        self.num_detectors = int(num_detectors)
        self.num_observables = int(num_observables)
        self.heuristic_name = heuristic
        self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic
        self.verbose_search = verbose_search

        self.probabilities = np.array([err.probability for err in self.errors], dtype=np.float64)
        self.weights = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64)
        self.error_detectors: List[Tuple[int, ...]] = [tuple(err.detectors) for err in self.errors]
        self.error_observables: List[Tuple[int, ...]] = [tuple(err.observables) for err in self.errors]

        # Inverse incidence map: detector index -> indices of errors flipping it.
        d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)]
        for ei, dets in enumerate(self.error_detectors):
            for d in dets:
                d2e_lists[d].append(ei)
        self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists]

        self.heuristic_calls = 0

    def reset_stats(self) -> None:
        """Clear per-decode statistics."""
        self.heuristic_calls = 0

    def build_support_data(self, active_dets: np.ndarray, available_errs: np.ndarray) -> SupportData:
        """Project the available errors onto the active detectors.

        Errors with identical nonempty active supports are deduplicated,
        keeping only the cheapest weight per support.
        """
        active_list = sorted(map(int, np.flatnonzero(active_dets)))
        incident: Dict[int, List[int]] = {d: [] for d in active_list}
        support_to_weight: Dict[Tuple[int, ...], float] = {}

        for ei in np.flatnonzero(available_errs):
            ei = int(ei)
            support = tuple(d for d in self.error_detectors[ei] if active_dets[d])
            if not support:
                continue
            weight = float(self.weights[ei])
            old = support_to_weight.get(support)
            if old is None or weight < old:
                support_to_weight[support] = weight

        supports = list(support_to_weight.items())
        for i, (support, _weight) in enumerate(supports):
            for d in support:
                if d in incident:
                    incident[d].append(i)

        return SupportData(active_detectors=active_list, supports=supports, incident=incident)

    def _check_coverage(self, support_data: SupportData) -> bool:
        """True iff every active detector is touched by at least one support."""
        return all(len(support_data.incident[d]) > 0 for d in support_data.active_detectors)

    def heuristic_plain(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]:
        """Plain feasible point: y_d = min over incident supports of W(T)/|T|.

        Returns (sum of y over active detectors, y) or (inf, None) if some
        active detector has no incident support.
        """
        if not support_data.active_detectors:
            return 0.0, np.zeros(self.num_detectors, dtype=np.float64)
        if not self._check_coverage(support_data):
            return math.inf, None
        y = np.zeros(self.num_detectors, dtype=np.float64)
        for d in support_data.active_detectors:
            best = math.inf
            for i in support_data.incident[d]:
                support, weight = support_data.supports[i]
                best = min(best, weight / len(support))
            y[d] = best
        return float(y[support_data.active_detectors].sum()), y

    def heuristic_saturation_zero(self, support_data: SupportData, *, order_kind: str) -> Tuple[float, Optional[np.ndarray]]:
        """Greedy saturation from y = 0 in a fixed detector order.

        order_kind="asc_deg" saturates low-degree detectors first;
        order_kind="desc_plain" saturates detectors with the largest plain
        heuristic value first.
        """
        if not support_data.active_detectors:
            return 0.0, np.zeros(self.num_detectors, dtype=np.float64)
        if not self._check_coverage(support_data):
            return math.inf, None

        slack = np.array([weight for _support, weight in support_data.supports], dtype=np.float64)
        y = np.zeros(self.num_detectors, dtype=np.float64)

        if order_kind == "asc_deg":
            order = sorted(support_data.active_detectors, key=lambda d: (len(support_data.incident[d]), d))
        elif order_kind == "desc_plain":
            _plain_value, y_plain = self.heuristic_plain(support_data)
            if y_plain is None:
                return math.inf, None
            order = sorted(support_data.active_detectors, key=lambda d: (y_plain[d], d), reverse=True)
        else:
            raise ValueError(f"Unknown order_kind={order_kind!r}")

        for d in order:
            # Raise y_d to the tightest remaining incident slack.
            value = min(slack[i] for i in support_data.incident[d])
            if value < 0:
                value = 0.0
            y[d] = value
            for i in support_data.incident[d]:
                slack[i] -= value
        return float(y[support_data.active_detectors].sum()), y

    def heuristic_plain_sweep(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]:
        """Start from the plain point, then saturate detectors in descending-y order."""
        _plain_value, y = self.heuristic_plain(support_data)
        if y is None:
            return math.inf, None
        order = sorted(support_data.active_detectors, key=lambda d: (y[d], d), reverse=True)
        for d in order:
            # Largest y_d keeping every constraint that involves d satisfied.
            # Iterate the precomputed incidence list instead of scanning all
            # supports for membership — same minimum, far fewer terms.
            max_feasible = min(
                weight - sum(y[dd] for dd in support if dd != d)
                for support, weight in (support_data.supports[i] for i in support_data.incident[d])
            )
            if max_feasible > y[d]:
                y[d] = max_feasible
        return float(y[support_data.active_detectors].sum()), y

    def heuristic_exact_lp(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]:
        """Exact optimum of the singleton LP, solved per connected component.

        Detectors linked by a shared support are grouped with a union-find,
        then each component's LP is solved independently with HiGHS.
        """
        active = support_data.active_detectors
        if not active:
            return 0.0, np.zeros(self.num_detectors, dtype=np.float64)
        if not self._check_coverage(support_data):
            return math.inf, None

        detector_index = {d: i for i, d in enumerate(active)}
        uf = UnionFind(len(active))
        for support, _weight in support_data.supports:
            if len(support) > 1:
                a = detector_index[support[0]]
                for d in support[1:]:
                    uf.union(a, detector_index[d])

        components: Dict[int, List[int]] = defaultdict(list)
        for d in active:
            components[uf.find(detector_index[d])].append(d)

        y = np.zeros(self.num_detectors, dtype=np.float64)
        total = 0.0
        for component in components.values():
            component_set = set(component)
            ordered = sorted(component)
            local = {d: i for i, d in enumerate(ordered)}
            # A support lies wholly inside one component, so testing its first
            # detector suffices.
            component_supports: List[Tuple[Tuple[int, ...], float]] = []
            for support, weight in support_data.supports:
                if support[0] in component_set:
                    component_supports.append((tuple(local[d] for d in support), weight))

            rows: List[int] = []
            cols: List[int] = []
            data: List[float] = []
            rhs: List[float] = []
            for r, (support, weight) in enumerate(component_supports):
                rhs.append(weight)
                for c in support:
                    rows.append(r)
                    cols.append(c)
                    data.append(1.0)

            a_ub = csr_matrix(
                (data, (rows, cols)),
                shape=(len(component_supports), len(component)),
                dtype=np.float64,
            )
            # linprog minimizes, so negate the all-ones objective to maximize.
            result = linprog(
                c=-np.ones(len(component), dtype=np.float64),
                A_ub=a_ub,
                b_ub=np.array(rhs, dtype=np.float64),
                bounds=[(0.0, None)] * len(component),
                method="highs",
            )
            if not result.success:
                return math.inf, None
            total += -float(result.fun)
            for d, value in zip(ordered, result.x):
                y[d] = float(value)
        return float(total), y

    def evaluate_named_heuristic(self, support_data: SupportData, name: str) -> Tuple[float, Optional[np.ndarray]]:
        """Dispatch to a heuristic by name; see the class docstring for the LP."""
        if name == "plain":
            return self.heuristic_plain(support_data)
        if name == "asc_deg":
            return self.heuristic_saturation_zero(support_data, order_kind="asc_deg")
        if name == "desc_plain":
            return self.heuristic_saturation_zero(support_data, order_kind="desc_plain")
        if name == "plain_sweep":
            return self.heuristic_plain_sweep(support_data)
        if name == "best_of_two":
            v1, y1 = self.heuristic_plain_sweep(support_data)
            v2, y2 = self.heuristic_saturation_zero(support_data, order_kind="asc_deg")
            if v1 >= v2:
                return v1, y1
            return v2, y2
        if name == "best_of_three":
            candidates = [
                self.heuristic_plain_sweep(support_data),
                self.heuristic_saturation_zero(support_data, order_kind="asc_deg"),
                self.heuristic_saturation_zero(support_data, order_kind="desc_plain"),
            ]
            return max(candidates, key=lambda t: t[0])
        if name == "exact_lp":
            return self.heuristic_exact_lp(support_data)
        raise ValueError(f"Unknown heuristic {name!r}")

    def compute_heuristic(self, dets: np.ndarray, errs: np.ndarray, blocked_errs: np.ndarray) -> float:
        """Lower bound on the residual cost for a search state (inf = infeasible)."""
        self.heuristic_calls += 1
        available = ~errs
        if self.respect_blocked_errors_in_heuristic:
            available &= ~blocked_errs
        support_data = self.build_support_data(dets, available)
        value, _ = self.evaluate_named_heuristic(support_data, self.heuristic_name)
        return value
def report_root_heuristics(self, dets: np.ndarray, errs: np.ndarray, blocked_errs: np.ndarray) -> List[Tuple[str, float]]: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + support_data = self.build_support_data(dets, available) + names = ["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp"] + out: List[Tuple[str, float]] = [] + for name in names: + value, _ = self.evaluate_named_heuristic(support_data, name) + out.append((name, value)) + return out + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + h0 = self.compute_heuristic(dets0, errs0, blocked0) + if math.isinf(h0): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + heuristic_calls=self.heuristic_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + heap: List[Tuple[float, int, int, SearchState]] = [] + counter = 0 + root_state = SearchState(errs=errs0, blocked_errs=blocked0, dets=dets0, g_cost=0.0) + heapq.heappush(heap, (h0, int(dets0.sum()), counter, root_state)) + counter += 1 + nodes_pushed = 1 + nodes_popped = 0 + min_num_dets = int(dets0.sum()) + + while heap: + f_cost, num_dets, _entry_id, state = heapq.heappop(heap) + nodes_popped += 1 + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + if self.verbose_search: + print( + f"len(heap)={len(heap)} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f}" + ) + + if num_dets == 0: + return DecodeResult( + success=True, + 
errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked = state.blocked_errs.copy() + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked[ei] = True + if state.errs[ei] or state.blocked_errs[ei]: + continue + + child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked = prefix_blocked.copy() + child_dets = state.dets.copy() + for d in self.error_detectors[ei]: + child_dets[d] ^= True + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + child_g = state.g_cost + float(self.weights[ei]) + child_h = self.compute_heuristic(child_dets, child_errs, child_blocked) + if math.isinf(child_h): + children_infeasible += 1 + continue + child_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked, + dets=child_dets, + g_cost=child_g, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, counter, child_state)) + counter += 1 + nodes_pushed += 1 + children_generated += 1 + + if self.verbose_search: + print( + f" expanded children_generated={children_generated} beam_pruned={children_beam_pruned} " + f"infeasible={children_infeasible}" + ) + + return DecodeResult( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.weights[errs].sum()) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in 
np.flatnonzero(errs): + for d in self.error_detectors[int(ei)]: + dets[d] ^= True + return dets + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.error_observables[int(ei)]: + parity[int(obs)] = not parity.get(int(obs), False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + +def sample_detections_and_observables( + circuit: stim.Circuit, + *, + num_shots: int, + seed: int, + num_detectors: int, + num_observables: int, +) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics." 
+ ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample from Stim before selecting --shot (default: 100).", + ) + parser.add_argument( + "--shot", + type=int, + default=0, + help="Index of the sampled shot to decode (default: 0).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the residual detector count; use 'inf' for none.", + ) + parser.add_argument( + "--heuristic", + choices=["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp"], + default="best_of_two", + help="Lower-bound heuristic to use during A* search (default: best_of_two).", + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action=argparse.BooleanOptionalAction, + default=True, + help="Exclude precedence-blocked errors when forming the lower bound (default: enabled).", + ) + parser.add_argument( + "--report-all-root-heuristics", + action="store_true", + help="Print all root-node heuristic values, including exact_lp, for the selected shot.", + ) + parser.add_argument( + "--skip-decode", + action="store_true", + help="Only report root heuristics; do not run A* search.", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the selected shot's active detector IDs (default: enabled).", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the decoded merged-error indices when decoding 
succeeds (default: enabled).", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) if args.merge_errors else list(iter_dem_errors_from_dem(dem)) + + dets, obs = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + ) + shot_dets = dets[args.shot] + shot_obs = obs[args.shot] + + decoder = GreedySingletonHeuristicDecoder( + errors, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + heuristic=args.heuristic, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + verbose_search=args.verbose_search, + ) + + print(f"circuit = {args.circuit}") + print(f"heuristic = {args.heuristic}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"shot = {args.shot}") + print(f"num_errors = {decoder.num_errors}") + print(f"num_detectors = {decoder.num_detectors}") + print(f"num_observables = {decoder.num_observables}") + print(f"det_beam = {args.det_beam}") + print(f"merge_errors = {args.merge_errors}") + print(f"respect_blocked_errors_in_heuristic = {args.respect_blocked_errors_in_heuristic}") + + if args.show_shot_detectors: + active_dets = np.flatnonzero(shot_dets) + print("shot_detectors =", " ".join(f"D{d}" for d in active_dets)) + + if args.report_all_root_heuristics: + root_errs 
= np.zeros(decoder.num_errors, dtype=bool) + root_blocked = np.zeros(decoder.num_errors, dtype=bool) + report = decoder.report_root_heuristics(shot_dets, root_errs, root_blocked) + exact = next((v for k, v in report if k == "exact_lp"), None) + print("root_heuristics:") + for name, value in report: + if exact is not None and not math.isinf(exact) and exact > 0: + ratio = value / exact if not math.isinf(value) else INF + ratio_text = "INF" if math.isinf(ratio) else f"{ratio:.6f}" + else: + ratio_text = "n/a" + value_text = "INF" if math.isinf(value) else f"{value:.12f}" + print(f" {name:>12s} value={value_text} ratio_to_exact={ratio_text}") + + if args.skip_decode: + return 0 + + result = decoder.decode(shot_dets, det_beam=args.det_beam) + print(f"success = {result.success}") + print(f"nodes_pushed = {result.nodes_pushed}") + print(f"nodes_popped = {result.nodes_popped}") + print(f"heuristic_calls = {result.heuristic_calls}") + print(f"elapsed_seconds = {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + if args.show_error_indices: + print("decoded_error_indices =", " ".join(map(str, np.flatnonzero(result.errs).tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + decoded_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + sampled_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors = {int(result.errs.sum())}") + print(f"decoded_cost = {decoded_cost:.12f}") + print("predicted_observables =", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables =", " ".join(f"L{o}" for o in sampled_obs.tolist())) + print(f"observables_match = {bool(np.array_equal(predicted_obs, sampled_obs))}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git 
a/src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py b/src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py new file mode 100644 index 0000000..346a974 --- /dev/null +++ b/src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py @@ -0,0 +1,1128 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics. + +This version keeps the same Stim-facing API as the earlier greedy prototype but +adds lazy reinsertion / parent-y projection, in the same spirit as the lazy +optimal-singleton prototype: + + * nodes are seeded with a cheap feasible lower bound; + * when a node is popped, the selected heuristic is evaluated on that node; + * if the refined heuristic raises the node key, the node is reinserted; + * expanded nodes project their current feasible y-prices onto children; + * optionally, the projected child bound is maxed with plain detcost. + +Supported heuristic choices: + plain original detector-wise feasible point + asc_deg zero-start saturation ordered by ascending detector degree + desc_plain zero-start saturation ordered by descending plain y_d + plain_sweep start from plain, then one descending saturation sweep + best_of_two max(plain_sweep, asc_deg) + best_of_three max(plain_sweep, asc_deg, desc_plain) + exact_lp exact optimal singleton LP lower bound + +When --lazy-reinsert-heuristics is enabled (the default), the root is seeded by +plain detcost and only popped nodes are refined with the selected heuristic. +This works for all of the above heuristics because each returns a feasible +singleton-budget vector y, and projecting that y to a child by keeping prices +on detectors that remain active and zeroing newly active detectors is still a +feasible child singleton-budget point. 
+""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] + + +@dataclass +class SupportData: + active_detectors: List[int] + supports: List[Tuple[Tuple[int, ...], float]] + incident: Dict[int, List[int]] + + +@dataclass +class SearchState: + errs: np.ndarray + blocked_errs: np.ndarray + dets: np.ndarray + det_counts: np.ndarray + g_cost: float + h_cost: float + h_source: str + refined: bool + y_prices: Optional[np.ndarray] + + +@dataclass +class DecodeResult: + success: bool + errs: np.ndarray + residual_dets: np.ndarray + cost: float + nodes_pushed: int + nodes_popped: int + max_queue_size: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + refinement_calls: int + lp_calls: int + reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_refinement_gain: float + max_refinement_gain: float + elapsed_seconds: float + + +class UnionFind: + def __init__(self, n: int) -> None: + self.parent = list(range(n)) + self.rank = [0] * n + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + 
+def xor_probability(p0: float, p1: float) -> float: + return p0 * (1.0 - p1) + (1.0 - p0) * p1 + + +def iter_dem_errors_from_dem(dem: stim.DetectorErrorModel) -> Iterable[ErrorRecord]: + for instruction in dem.flattened(): + if instruction.type != "error": + continue + probability = float(instruction.args_copy()[0]) + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + f"Expected flattened error probabilities in (0, 0.5), got {probability}." + ) + + detectors: set[int] = set() + observables: set[int] = set() + for target in instruction.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected DEM target: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + + yield ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=tuple(sorted(detectors)), + observables=tuple(sorted(observables)), + ) + + +def merged_errors_from_dem(dem: stim.DetectorErrorModel) -> List[ErrorRecord]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + for error in iter_dem_errors_from_dem(dem): + key = (error.detectors, error.observables) + p_old = errors_by_symptom.get(key) + if p_old is None: + p_new = error.probability + else: + p_new = xor_probability(p_old, error.probability) + errors_by_symptom[key] = p_new + + merged: List[ErrorRecord] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + f"Merged error has probability >= 0.5 ({probability}); cannot assign positive cost." 
+ ) + merged.append( + ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +class GreedySingletonHeuristicDecoder: + def __init__( + self, + errors: Sequence[ErrorRecord], + num_detectors: int, + num_observables: int, + *, + heuristic: str = "best_of_two", + respect_blocked_errors_in_heuristic: bool = True, + lazy_reinsert_heuristics: bool = True, + projection_combine_max_plain: bool = True, + verbose_search: bool = False, + ) -> None: + self.errors = list(errors) + self.num_errors = len(self.errors) + self.num_detectors = int(num_detectors) + self.num_observables = int(num_observables) + self.heuristic_name = heuristic + self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic + self.lazy_reinsert_heuristics = lazy_reinsert_heuristics + self.projection_combine_max_plain = projection_combine_max_plain + self.verbose_search = verbose_search + + self.probabilities = np.array([err.probability for err in self.errors], dtype=np.float64) + self.weights = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64) + self.error_detectors: List[Tuple[int, ...]] = [tuple(err.detectors) for err in self.errors] + self.error_observables: List[Tuple[int, ...]] = [tuple(err.observables) for err in self.errors] + + d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)] + for ei, dets in enumerate(self.error_detectors): + for d in dets: + d2e_lists[d].append(ei) + self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists] + + self.reset_stats() + + def reset_stats(self) -> None: + self.heuristic_calls = 0 + self.plain_heuristic_calls = 0 + self.projection_heuristic_calls = 0 + self.refinement_calls = 0 + self.lp_calls = 0 + self.reinserts = 0 + self.projected_nodes_generated = 0 + self.projected_nodes_refined = 0 + self.total_refinement_gain = 0.0 + self.max_refinement_gain = 0.0 + + 
@property + def mode_name(self) -> str: + if self.heuristic_name == "plain": + return "plain" + if self.lazy_reinsert_heuristics: + suffix = "-lazy-projection" + if self.projection_combine_max_plain: + suffix += "-maxplain" + return f"{self.heuristic_name}{suffix}" + return self.heuristic_name + + def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + return available + + def _has_cover_for_all_active_detectors(self, dets: np.ndarray, available_errs: np.ndarray) -> bool: + for d in np.flatnonzero(dets): + found = False + for ei in self.d2e[int(d)]: + if available_errs[int(ei)]: + found = True + break + if not found: + return False + return True + + def build_support_data(self, active_dets: np.ndarray, available_errs: np.ndarray) -> SupportData: + active_list = sorted(map(int, np.flatnonzero(active_dets))) + incident: Dict[int, List[int]] = {d: [] for d in active_list} + support_to_weight: Dict[Tuple[int, ...], float] = {} + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + support = tuple(d for d in self.error_detectors[ei] if active_dets[d]) + if not support: + continue + weight = float(self.weights[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + supports = list(support_to_weight.items()) + for i, (support, _weight) in enumerate(supports): + for d in support: + if d in incident: + incident[d].append(i) + + return SupportData(active_detectors=active_list, supports=supports, incident=incident) + + def _check_coverage(self, support_data: SupportData) -> bool: + return all(len(support_data.incident[d]) > 0 for d in support_data.active_detectors) + + def plain_detcost_from_counts( + self, + dets: np.ndarray, + available_errs: np.ndarray, + det_counts: np.ndarray, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + 
self.plain_heuristic_calls += 1 + active = np.flatnonzero(dets) + if active.size == 0: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for d in active: + best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + assert count > 0 + value = self.weights[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF, None + y[int(d)] = best + total += best + return total, y + + def heuristic_plain(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + y = np.zeros(self.num_detectors, dtype=np.float64) + for d in support_data.active_detectors: + best = INF + for i in support_data.incident[d]: + support, weight = support_data.supports[i] + best = min(best, weight / len(support)) + y[d] = best + return float(y[support_data.active_detectors].sum()), y + + def heuristic_saturation_zero(self, support_data: SupportData, *, order_kind: str) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + slack = np.array([weight for _support, weight in support_data.supports], dtype=np.float64) + y = np.zeros(self.num_detectors, dtype=np.float64) + + if order_kind == "asc_deg": + order = sorted(support_data.active_detectors, key=lambda d: (len(support_data.incident[d]), d)) + elif order_kind == "desc_plain": + _plain_value, y_plain = self.heuristic_plain(support_data) + if y_plain is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y_plain[d], d), reverse=True) + else: + raise ValueError(f"Unknown order_kind={order_kind!r}") + + for d in order: + 
value = min(slack[i] for i in support_data.incident[d]) + if value < 0: + value = 0.0 + y[d] = value + for i in support_data.incident[d]: + slack[i] -= value + return float(y[support_data.active_detectors].sum()), y + + def heuristic_plain_sweep(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + plain_value, y = self.heuristic_plain(support_data) + if y is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y[d], d), reverse=True) + for d in order: + max_feasible = min( + weight - sum(y[dd] for dd in support if dd != d) + for support, weight in support_data.supports + if d in support + ) + if max_feasible > y[d]: + y[d] = max_feasible + return float(y[support_data.active_detectors].sum()), y + + def heuristic_exact_lp(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + active = support_data.active_detectors + if not active: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + detector_index = {d: i for i, d in enumerate(active)} + uf = UnionFind(len(active)) + for support, _weight in support_data.supports: + if len(support) > 1: + a = detector_index[support[0]] + for d in support[1:]: + uf.union(a, detector_index[d]) + + components: Dict[int, List[int]] = defaultdict(list) + for d in active: + components[uf.find(detector_index[d])].append(d) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for component in components.values(): + component_set = set(component) + local = {d: i for i, d in enumerate(sorted(component))} + component_supports: List[Tuple[Tuple[int, ...], float]] = [] + for support, weight in support_data.supports: + if support[0] in component_set: + component_supports.append((tuple(local[d] for d in support), weight)) + + rows: List[int] = [] + cols: List[int] = [] + data: List[float] = [] + rhs: List[float] = [] + for r, (support, weight) in enumerate(component_supports): + 
rhs.append(weight) + for c in support: + rows.append(r) + cols.append(c) + data.append(1.0) + + a_ub = csr_matrix( + (data, (rows, cols)), + shape=(len(component_supports), len(component)), + dtype=np.float64, + ) + self.lp_calls += 1 + result = linprog( + c=-np.ones(len(component), dtype=np.float64), + A_ub=a_ub, + b_ub=np.array(rhs, dtype=np.float64), + bounds=[(0.0, None)] * len(component), + method="highs", + ) + if not result.success: + return INF, None + total += -float(result.fun) + for d, value in zip(sorted(component), result.x): + y[d] = float(value) + return float(total), y + + def evaluate_named_heuristic(self, support_data: SupportData, name: str) -> Tuple[float, Optional[np.ndarray]]: + if name == "plain": + return self.heuristic_plain(support_data) + if name == "asc_deg": + return self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if name == "desc_plain": + return self.heuristic_saturation_zero(support_data, order_kind="desc_plain") + if name == "plain_sweep": + return self.heuristic_plain_sweep(support_data) + if name == "best_of_two": + v1, y1 = self.heuristic_plain_sweep(support_data) + v2, y2 = self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if v1 >= v2: + return v1, y1 + return v2, y2 + if name == "best_of_three": + candidates = [ + self.heuristic_plain_sweep(support_data), + self.heuristic_saturation_zero(support_data, order_kind="asc_deg"), + self.heuristic_saturation_zero(support_data, order_kind="desc_plain"), + ] + return max(candidates, key=lambda t: t[0]) + if name == "exact_lp": + return self.heuristic_exact_lp(support_data) + raise ValueError(f"Unknown heuristic {name!r}") + + def compute_support_based_heuristic( + self, + dets: np.ndarray, + errs: np.ndarray, + blocked_errs: np.ndarray, + *, + name: Optional[str] = None, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + available = self._available_errors(errs, blocked_errs) + support_data = self.build_support_data(dets, 
available) + return self.evaluate_named_heuristic(support_data, name or self.heuristic_name) + + def project_child_y( + self, + parent_state: SearchState, + child_dets: np.ndarray, + child_errs: np.ndarray, + child_blocked_errs: np.ndarray, + child_det_counts: np.ndarray, + flipped_detectors: Sequence[int], + ) -> Tuple[float, Optional[np.ndarray], str]: + if parent_state.y_prices is None: + raise AssertionError("Expected a stored feasible y vector before projecting to a child.") + + self.heuristic_calls += 1 + self.projection_heuristic_calls += 1 + available = self._available_errors(child_errs, child_blocked_errs) + if not self._has_cover_for_all_active_detectors(child_dets, available): + return INF, None, "projected" + + y_projected = np.zeros(self.num_detectors, dtype=np.float64) + keep = parent_state.dets & child_dets + y_projected[keep] = parent_state.y_prices[keep] + projected_value = float(y_projected[np.flatnonzero(child_dets)].sum()) + best_value = projected_value + best_y = y_projected + best_source = "projected" + + if self.projection_combine_max_plain: + plain_value, plain_y = self.plain_detcost_from_counts(child_dets, available, child_det_counts) + if plain_y is None: + return INF, None, "plain" + if plain_value > best_value + HEURISTIC_EPS: + best_value = plain_value + best_y = plain_y + best_source = "plain" + + return best_value, best_y, best_source + + def report_root_heuristics(self, dets: np.ndarray, errs: np.ndarray, blocked_errs: np.ndarray) -> List[Tuple[str, float]]: + available = self._available_errors(errs, blocked_errs) + support_data = self.build_support_data(dets, available) + names = ["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp"] + out: List[Tuple[str, float]] = [] + saved_lp_calls = self.lp_calls + for name in names: + value, _ = self.evaluate_named_heuristic(support_data, name) + out.append((name, value)) + self.lp_calls = saved_lp_calls + return out + + def _maybe_refine_node(self, 
state: SearchState) -> Tuple[SearchState, bool]: + if state.refined or self.heuristic_name == "plain" or not self.lazy_reinsert_heuristics: + return state, False + + previous_source = state.h_source + self.refinement_calls += 1 + new_value, new_y = self.compute_support_based_heuristic( + state.dets, + state.errs, + state.blocked_errs, + name=self.heuristic_name, + ) + if new_y is None: + if previous_source == "projected": + self.projected_nodes_refined += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost:.6f} new_h=INF delta=INF reinserted=False discarded=True" + ) + state.refined = True + return state, True + + delta = new_value - state.h_cost + self.total_refinement_gain += max(0.0, delta) + self.max_refinement_gain = max(self.max_refinement_gain, max(0.0, delta)) + + if self.heuristic_name == "exact_lp" and new_value + 1e-7 < state.h_cost: + raise AssertionError( + f"Exact LP value {new_value} below stored projected value {state.h_cost}." + ) + + if new_value > state.h_cost + HEURISTIC_EPS: + if previous_source == "projected": + self.projected_nodes_refined += 1 + state.h_cost = new_value + state.h_source = "refined" + state.y_prices = new_y + state.refined = True + self.reinserts += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost - delta:.6f} new_h={new_value:.6f} delta={delta:.6f} reinserted=True discarded=False" + ) + return state, True + + # Non-improving greedy recomputation: keep the existing projected feasible point. 
+ if previous_source == "projected": + self.projected_nodes_refined += 1 + if abs(new_value - state.h_cost) <= HEURISTIC_EPS: + state.y_prices = new_y + state.refined = True + if self.verbose_search: + new_text = "INF" if math.isinf(new_value) else f"{new_value:.6f}" + print( + f" refine approx_h={state.h_cost:.6f} new_h={new_text} delta={delta:.6f} reinserted=False discarded=False" + ) + return state, False + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + root_h, root_y = self.plain_detcost_from_counts(dets0, self._available_errors(errs0, blocked0), det_counts0) + if root_y is None or math.isinf(root_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + root_refined = (self.heuristic_name == "plain") or (not self.lazy_reinsert_heuristics) + if root_refined and self.heuristic_name != "plain": + # Eager mode: use the selected heuristic immediately. 
+ eager_h, eager_y = self.compute_support_based_heuristic(dets0, errs0, blocked0, name=self.heuristic_name) + if eager_y is None or math.isinf(eager_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + root_h, root_y = eager_h, eager_y + + root_state = SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + h_cost=root_h, + h_source="plain" if not root_refined else ("plain" if self.heuristic_name == "plain" else "refined"), + refined=root_refined, + y_prices=root_y, + ) + + heap: List[Tuple[float, int, int, SearchState]] = [] + counter = 0 + heapq.heappush(heap, (root_state.g_cost + root_state.h_cost, int(dets0.sum()), counter, root_state)) + counter += 1 + nodes_pushed = 1 + nodes_popped = 0 + max_queue_size = 1 + min_num_dets = int(dets0.sum()) + + while heap: + max_queue_size = max(max_queue_size, len(heap)) + f_cost, num_dets, _entry_id, state = heapq.heappop(heap) + nodes_popped += 1 + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + 
f"len(heap)={len(heap)} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"lp_calls={self.lp_calls} reinserts={self.reinserts} proj_generated={self.projected_nodes_generated} " + f"proj_refined={self.projected_nodes_refined} proj_unrefined_so_far={projected_unrefined} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f} " + f"h={state.h_cost:.6f} h_source={state.h_source} refined={state.refined}" + ) + + if num_dets == 0: + return DecodeResult( + success=True, + errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + state, should_reinsert = self._maybe_refine_node(state) + if should_reinsert: + if state.y_prices is None or math.isinf(state.h_cost): + if state.h_source == "projected": + self.projected_nodes_refined += 1 + continue + if state.h_source != "plain": + heapq.heappush(heap, (state.g_cost + state.h_cost, num_dets, counter, state)) + counter += 1 + continue + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked = state.blocked_errs.copy() + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + children_projected = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked[ei] = True + if state.errs[ei] or state.blocked_errs[ei]: + continue + + 
child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked = prefix_blocked.copy() + child_dets = state.dets.copy() + child_det_counts = state.det_counts.copy() + for d in self.error_detectors[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_g = state.g_cost + float(self.weights[ei]) + if self.heuristic_name == "plain" or (not self.lazy_reinsert_heuristics): + child_h, child_y = self.compute_support_based_heuristic( + child_dets, child_errs, child_blocked, name=self.heuristic_name + ) + child_source = "plain" if self.heuristic_name == "plain" else "refined" + child_refined = True + else: + if state.y_prices is None: + raise AssertionError("Expected parent feasible y-prices before projecting to child.") + child_h, child_y, child_source = self.project_child_y( + state, + child_dets, + child_errs, + child_blocked, + child_det_counts, + self.error_detectors[ei], + ) + self.projected_nodes_generated += 1 + children_projected += 1 + child_refined = False + + if child_y is None or math.isinf(child_h): + children_infeasible += 1 + continue + + child_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=child_h, + h_source=child_source, + refined=child_refined, + y_prices=child_y, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, counter, child_state)) + counter += 1 + nodes_pushed += 1 + children_generated += 1 + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f" expanded children_generated={children_generated} children_projected={children_projected} " + f"beam_pruned={children_beam_pruned} 
infeasible={children_infeasible} " + f"lp_calls={self.lp_calls} proj_unrefined_so_far={projected_unrefined}" + ) + + return DecodeResult( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.weights[errs].sum()) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.error_detectors[int(ei)]: + dets[d] ^= True + return dets + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.error_observables[int(ei)]: + parity[int(obs)] = not parity.get(int(obs), False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + +def sample_detections_and_observables( + circuit: stim.Circuit, + *, + num_shots: int, + seed: int, + num_detectors: int, + num_observables: int, +) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + 
dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics." + ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--dets", + type=str, + default=None, + help="String of shot dets (e.g., 'shot D0 D1 L2') to parse instead of sampling.", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample from Stim before selecting --shot (default: 100).", + ) + parser.add_argument( + "--shot", + type=int, + default=0, + help="Index of the sampled shot to decode (default: 0).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the residual detector count; use 'inf' for none.", + ) + parser.add_argument( + "--heuristic", + choices=["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp"], + default="best_of_two", + help="Lower-bound heuristic to use during A* search (default: best_of_two).", + ) + parser.add_argument( + "--lazy-reinsert-heuristics", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For non-plain heuristics, seed nodes with plain detcost, 
refine on pop, and reinsert when the selected " + "heuristic improves the key (default: enabled)." + ), + ) + parser.add_argument( + "--projection-combine-max-plain", + action=argparse.BooleanOptionalAction, + default=True, + help="When projecting parent y-prices to a child, take max(projected, plain detcost) (default: enabled).", + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action=argparse.BooleanOptionalAction, + default=True, + help="Exclude precedence-blocked errors when forming the lower bound (default: enabled).", + ) + parser.add_argument( + "--report-all-root-heuristics", + action="store_true", + help="Print all root-node heuristic values, including exact_lp, for the selected shot.", + ) + parser.add_argument( + "--skip-decode", + action="store_true", + help="Only report root heuristics; do not run A* search.", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the selected shot's active detector IDs (default: enabled).", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the decoded merged-error indices when decoding succeeds (default: enabled).", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + circuit = 
stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) if args.merge_errors else list(iter_dem_errors_from_dem(dem)) + + if args.dets is not None: + shot_dets = np.zeros(dem.num_detectors, dtype=bool) + shot_obs = np.zeros(dem.num_observables, dtype=bool) + for token in args.dets.split(): + if token == "shot": + continue + if token.startswith("D") and token[1:].isdigit(): + d_idx = int(token[1:]) + if d_idx < dem.num_detectors: + shot_dets[d_idx] = True + elif token.startswith("L") and token[1:].isdigit(): + l_idx = int(token[1:]) + if l_idx < dem.num_observables: + shot_obs[l_idx] = True + else: + dets, obs = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + ) + shot_dets = dets[args.shot] + shot_obs = obs[args.shot] + + decoder = GreedySingletonHeuristicDecoder( + errors, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + heuristic=args.heuristic, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + lazy_reinsert_heuristics=args.lazy_reinsert_heuristics, + projection_combine_max_plain=args.projection_combine_max_plain, + verbose_search=args.verbose_search, + ) + + print(f"circuit = {args.circuit}") + print(f"heuristic = {args.heuristic}") + print(f"mode = {decoder.mode_name}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"shot = {args.shot}") + print(f"num_errors = {decoder.num_errors}") + print(f"num_detectors = {decoder.num_detectors}") + print(f"num_observables = {decoder.num_observables}") + print(f"det_beam = {args.det_beam}") + print(f"merge_errors = {args.merge_errors}") + print(f"respect_blocked_errors_in_heuristic = {args.respect_blocked_errors_in_heuristic}") + print(f"lazy_reinsert_heuristics = {args.lazy_reinsert_heuristics}") + print(f"projection_combine_max_plain = 
{args.projection_combine_max_plain}") + + if args.show_shot_detectors: + active_dets = np.flatnonzero(shot_dets) + print("shot_detectors =", " ".join(f"D{d}" for d in active_dets)) + + if args.report_all_root_heuristics: + root_errs = np.zeros(decoder.num_errors, dtype=bool) + root_blocked = np.zeros(decoder.num_errors, dtype=bool) + report = decoder.report_root_heuristics(shot_dets, root_errs, root_blocked) + exact = next((v for k, v in report if k == "exact_lp"), None) + print("root_heuristics:") + for name, value in report: + if exact is not None and not math.isinf(exact) and exact > 0: + ratio = value / exact if not math.isinf(value) else INF + ratio_text = "INF" if math.isinf(ratio) else f"{ratio:.6f}" + else: + ratio_text = "n/a" + value_text = "INF" if math.isinf(value) else f"{value:.12f}" + print(f" {name:>12s} value={value_text} ratio_to_exact={ratio_text}") + + if args.skip_decode: + return 0 + + result = decoder.decode(shot_dets, det_beam=args.det_beam) + print(f"success = {result.success}") + print(f"nodes_pushed = {result.nodes_pushed}") + print(f"nodes_popped = {result.nodes_popped}") + print(f"max_queue_size = {result.max_queue_size}") + print(f"heuristic_calls = {result.heuristic_calls}") + print(f"plain_heuristic_calls = {result.plain_heuristic_calls}") + print(f"projection_heuristic_calls = {result.projection_heuristic_calls}") + print(f"refinement_calls = {result.refinement_calls}") + print(f"lp_calls = {result.lp_calls}") + print(f"reinserts = {result.reinserts}") + print(f"projected_nodes_generated = {result.projected_nodes_generated}") + print(f"projected_nodes_refined = {result.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish = {result.projected_nodes_unrefined_at_finish}") + print(f"total_refinement_gain = {result.total_refinement_gain:.6f}") + print(f"max_refinement_gain = {result.max_refinement_gain:.6f}") + print(f"elapsed_seconds = {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") 
+ return 1 + + if args.show_error_indices: + print("decoded_error_indices =", " ".join(map(str, np.flatnonzero(result.errs).tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + decoded_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + sampled_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors = {int(result.errs.sum())}") + print(f"decoded_cost = {decoded_cost:.12f}") + print("predicted_observables =", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables =", " ".join(f"L{o}" for o in sampled_obs.tolist())) + print(f"observables_match = {bool(np.array_equal(predicted_obs, sampled_obs))}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_singleton_greedy_heuristics_plus_inactive_lift_lazy.py b/src/py/astar/astar_prototype_singleton_greedy_heuristics_plus_inactive_lift_lazy.py new file mode 100644 index 0000000..9674b48 --- /dev/null +++ b/src/py/astar/astar_prototype_singleton_greedy_heuristics_plus_inactive_lift_lazy.py @@ -0,0 +1,1207 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics. + +This version keeps the same Stim-facing API as the earlier greedy prototype but +adds lazy reinsertion / parent-y projection, in the same spirit as the lazy +optimal-singleton prototype: + + * nodes are seeded with a cheap feasible lower bound; + * when a node is popped, the selected heuristic is evaluated on that node; + * if the refined heuristic raises the node key, the node is reinserted; + * expanded nodes project their current feasible y-prices onto children; + * optionally, the projected child bound is maxed with plain detcost. 
+ +Supported heuristic choices: + plain original detector-wise feasible point + asc_deg zero-start saturation ordered by ascending detector degree + desc_plain zero-start saturation ordered by descending plain y_d + plain_sweep start from plain, then one descending saturation sweep + best_of_two max(plain_sweep, asc_deg) + best_of_three max(plain_sweep, asc_deg, desc_plain) + exact_lp exact optimal singleton LP lower bound + lifted_sweep 1-pass inactive bounds transferred to plain_sweep + lifted_exact_lp 1-pass inactive bounds transferred to exact_lp + +When --lazy-reinsert-heuristics is enabled (the default), the root is seeded by +plain detcost and only popped nodes are refined with the selected heuristic. +This works for all of the above heuristics because each returns a feasible +singleton-budget vector y, and projecting that y to a child by keeping prices +on detectors that remain active and zeroing newly active detectors is still a +feasible child singleton-budget point. +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
+ + +@dataclass +class SupportData: + active_detectors: List[int] + supports: List[Tuple[Tuple[int, ...], float]] + incident: Dict[int, List[int]] + + +@dataclass +class SearchState: + errs: np.ndarray + blocked_errs: np.ndarray + dets: np.ndarray + det_counts: np.ndarray + g_cost: float + h_cost: float + h_source: str + refined: bool + y_prices: Optional[np.ndarray] + + +@dataclass +class DecodeResult: + success: bool + errs: np.ndarray + residual_dets: np.ndarray + cost: float + nodes_pushed: int + nodes_popped: int + max_queue_size: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + refinement_calls: int + lp_calls: int + reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_refinement_gain: float + max_refinement_gain: float + elapsed_seconds: float + + +class UnionFind: + def __init__(self, n: int) -> None: + self.parent = list(range(n)) + self.rank = [0] * n + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + +def xor_probability(p0: float, p1: float) -> float: + return p0 * (1.0 - p1) + (1.0 - p0) * p1 + + +def iter_dem_errors_from_dem(dem: stim.DetectorErrorModel) -> Iterable[ErrorRecord]: + for instruction in dem.flattened(): + if instruction.type != "error": + continue + probability = float(instruction.args_copy()[0]) + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + f"Expected flattened error probabilities in (0, 0.5), got {probability}." 
+ ) + + detectors: set[int] = set() + observables: set[int] = set() + for target in instruction.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected DEM target: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + + yield ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=tuple(sorted(detectors)), + observables=tuple(sorted(observables)), + ) + + +def merged_errors_from_dem(dem: stim.DetectorErrorModel) -> List[ErrorRecord]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + for error in iter_dem_errors_from_dem(dem): + key = (error.detectors, error.observables) + p_old = errors_by_symptom.get(key) + if p_old is None: + p_new = error.probability + else: + p_new = xor_probability(p_old, error.probability) + errors_by_symptom[key] = p_new + + merged: List[ErrorRecord] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + f"Merged error has probability >= 0.5 ({probability}); cannot assign positive cost." 
+ ) + merged.append( + ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +class GreedySingletonHeuristicDecoder: + def __init__( + self, + errors: Sequence[ErrorRecord], + num_detectors: int, + num_observables: int, + *, + heuristic: str = "best_of_two", + respect_blocked_errors_in_heuristic: bool = True, + lazy_reinsert_heuristics: bool = True, + projection_combine_max_plain: bool = True, + verbose_search: bool = False, + ) -> None: + self.errors = list(errors) + self.num_errors = len(self.errors) + self.num_detectors = int(num_detectors) + self.num_observables = int(num_observables) + self.heuristic_name = heuristic + self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic + self.lazy_reinsert_heuristics = lazy_reinsert_heuristics + self.projection_combine_max_plain = projection_combine_max_plain + self.verbose_search = verbose_search + + self.probabilities = np.array([err.probability for err in self.errors], dtype=np.float64) + self.weights = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64) + self.error_detectors: List[Tuple[int, ...]] = [tuple(err.detectors) for err in self.errors] + self.error_observables: List[Tuple[int, ...]] = [tuple(err.observables) for err in self.errors] + + d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)] + for ei, dets in enumerate(self.error_detectors): + for d in dets: + d2e_lists[d].append(ei) + self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists] + + self.reset_stats() + + def reset_stats(self) -> None: + self.heuristic_calls = 0 + self.plain_heuristic_calls = 0 + self.projection_heuristic_calls = 0 + self.refinement_calls = 0 + self.lp_calls = 0 + self.reinserts = 0 + self.projected_nodes_generated = 0 + self.projected_nodes_refined = 0 + self.total_refinement_gain = 0.0 + self.max_refinement_gain = 0.0 + + 
@property + def mode_name(self) -> str: + if self.heuristic_name == "plain": + return "plain" + if self.lazy_reinsert_heuristics: + suffix = "-lazy-projection" + if self.projection_combine_max_plain: + suffix += "-maxplain" + return f"{self.heuristic_name}{suffix}" + return self.heuristic_name + + def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + return available + + def _has_cover_for_all_active_detectors(self, dets: np.ndarray, available_errs: np.ndarray) -> bool: + for d in np.flatnonzero(dets): + found = False + for ei in self.d2e[int(d)]: + if available_errs[int(ei)]: + found = True + break + if not found: + return False + return True + + def _apply_inactive_lift(self, y: np.ndarray, active_dets: np.ndarray, available_errs: np.ndarray) -> float: + """Modifies y IN PLACE to increase values using a post-processing dual slack transfer.""" + slacks = np.zeros(self.num_errors, dtype=np.float64) + available_indices = np.flatnonzero(available_errs) + + # Calculate initial slacks + for ei in available_indices: + ei = int(ei) + slacks[ei] = self.weights[ei] + for d in self.error_detectors[ei]: + if active_dets[d]: + slacks[ei] -= y[d] + + active_list = np.flatnonzero(active_dets) + # Sort descending to attack the heaviest y-values first + order = sorted((int(d) for d in active_list), key=lambda d: y[d], reverse=True) + + for d in order: + incident_eis = [int(ei) for ei in self.d2e[d] if available_errs[ei]] + if not incident_eis: + continue + + min_s = min(slacks[ei] for ei in incident_eis) + + # If the detector is bottled-necked, try to transfer slack from inactive neighbors + if min_s < 1e-9: + blocking_eis = [ei for ei in incident_eis if slacks[ei] < 1e-9] + for ei in blocking_eis: + for d_inact in self.error_detectors[ei]: + if not active_dets[d_inact]: + siblings = [int(j) for j in self.d2e[d_inact] if available_errs[j] and j != ei] 
+ if not siblings: + continue + delta = min(slacks[j] for j in siblings) + if delta > 1e-9: + # Execute the transfer + slacks[ei] += delta + for j in siblings: + slacks[j] -= delta + break + + # Re-evaluate the bottleneck and lift if space was created + new_min_s = min(slacks[ei] for ei in incident_eis) + if new_min_s > 1e-9: + y[d] += new_min_s + for ei in incident_eis: + slacks[ei] -= new_min_s + + return float(y[active_dets].sum()) + + def build_support_data(self, active_dets: np.ndarray, available_errs: np.ndarray) -> SupportData: + active_list = sorted(map(int, np.flatnonzero(active_dets))) + incident: Dict[int, List[int]] = {d: [] for d in active_list} + support_to_weight: Dict[Tuple[int, ...], float] = {} + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + support = tuple(d for d in self.error_detectors[ei] if active_dets[d]) + if not support: + continue + weight = float(self.weights[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + supports = list(support_to_weight.items()) + for i, (support, _weight) in enumerate(supports): + for d in support: + if d in incident: + incident[d].append(i) + + return SupportData(active_detectors=active_list, supports=supports, incident=incident) + + def _check_coverage(self, support_data: SupportData) -> bool: + return all(len(support_data.incident[d]) > 0 for d in support_data.active_detectors) + + def plain_detcost_from_counts( + self, + dets: np.ndarray, + available_errs: np.ndarray, + det_counts: np.ndarray, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + self.plain_heuristic_calls += 1 + active = np.flatnonzero(dets) + if active.size == 0: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for d in active: + best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + 
assert count > 0 + value = self.weights[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF, None + y[int(d)] = best + total += best + return total, y + + def heuristic_plain(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + y = np.zeros(self.num_detectors, dtype=np.float64) + for d in support_data.active_detectors: + best = INF + for i in support_data.incident[d]: + support, weight = support_data.supports[i] + best = min(best, weight / len(support)) + y[d] = best + return float(y[support_data.active_detectors].sum()), y + + def heuristic_saturation_zero(self, support_data: SupportData, *, order_kind: str) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + slack = np.array([weight for _support, weight in support_data.supports], dtype=np.float64) + y = np.zeros(self.num_detectors, dtype=np.float64) + + if order_kind == "asc_deg": + order = sorted(support_data.active_detectors, key=lambda d: (len(support_data.incident[d]), d)) + elif order_kind == "desc_plain": + _plain_value, y_plain = self.heuristic_plain(support_data) + if y_plain is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y_plain[d], d), reverse=True) + else: + raise ValueError(f"Unknown order_kind={order_kind!r}") + + for d in order: + value = min(slack[i] for i in support_data.incident[d]) + if value < 0: + value = 0.0 + y[d] = value + for i in support_data.incident[d]: + slack[i] -= value + return float(y[support_data.active_detectors].sum()), y + + def heuristic_plain_sweep(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + plain_value, y = 
self.heuristic_plain(support_data) + if y is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y[d], d), reverse=True) + for d in order: + max_feasible = min( + weight - sum(y[dd] for dd in support if dd != d) + for support, weight in support_data.supports + if d in support + ) + if max_feasible > y[d]: + y[d] = max_feasible + return float(y[support_data.active_detectors].sum()), y + + def heuristic_exact_lp(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + active = support_data.active_detectors + if not active: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + detector_index = {d: i for i, d in enumerate(active)} + uf = UnionFind(len(active)) + for support, _weight in support_data.supports: + if len(support) > 1: + a = detector_index[support[0]] + for d in support[1:]: + uf.union(a, detector_index[d]) + + components: Dict[int, List[int]] = defaultdict(list) + for d in active: + components[uf.find(detector_index[d])].append(d) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for component in components.values(): + component_set = set(component) + local = {d: i for i, d in enumerate(sorted(component))} + component_supports: List[Tuple[Tuple[int, ...], float]] = [] + for support, weight in support_data.supports: + if support[0] in component_set: + component_supports.append((tuple(local[d] for d in support), weight)) + + rows: List[int] = [] + cols: List[int] = [] + data: List[float] = [] + rhs: List[float] = [] + for r, (support, weight) in enumerate(component_supports): + rhs.append(weight) + for c in support: + rows.append(r) + cols.append(c) + data.append(1.0) + + a_ub = csr_matrix( + (data, (rows, cols)), + shape=(len(component_supports), len(component)), + dtype=np.float64, + ) + self.lp_calls += 1 + result = linprog( + c=-np.ones(len(component), dtype=np.float64), + A_ub=a_ub, + b_ub=np.array(rhs, 
dtype=np.float64), + bounds=[(0.0, None)] * len(component), + method="highs", + ) + if not result.success: + return INF, None + total += -float(result.fun) + for d, value in zip(sorted(component), result.x): + y[d] = float(value) + return float(total), y + + def evaluate_named_heuristic(self, support_data: SupportData, name: str) -> Tuple[float, Optional[np.ndarray]]: + if name == "plain": + return self.heuristic_plain(support_data) + if name == "asc_deg": + return self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if name == "desc_plain": + return self.heuristic_saturation_zero(support_data, order_kind="desc_plain") + if name == "plain_sweep": + return self.heuristic_plain_sweep(support_data) + if name == "best_of_two": + v1, y1 = self.heuristic_plain_sweep(support_data) + v2, y2 = self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if v1 >= v2: + return v1, y1 + return v2, y2 + if name == "best_of_three": + candidates = [ + self.heuristic_plain_sweep(support_data), + self.heuristic_saturation_zero(support_data, order_kind="asc_deg"), + self.heuristic_saturation_zero(support_data, order_kind="desc_plain"), + ] + return max(candidates, key=lambda t: t[0]) + if name == "exact_lp": + return self.heuristic_exact_lp(support_data) + raise ValueError(f"Unknown heuristic {name!r}") + + def compute_support_based_heuristic( + self, + dets: np.ndarray, + errs: np.ndarray, + blocked_errs: np.ndarray, + *, + name: Optional[str] = None, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + available = self._available_errors(errs, blocked_errs) + h_name = name or self.heuristic_name + + support_data = self.build_support_data(dets, available) + + if h_name in {"lifted_sweep", "lifted_exact_lp"}: + if h_name == "lifted_sweep": + safe_val, safe_y = self.heuristic_plain_sweep(support_data) + else: + safe_val, safe_y = self.heuristic_exact_lp(support_data) + + if safe_y is None: + return INF, None + + lifted_val = 
self._apply_inactive_lift(safe_y.copy(), dets, available) + # Return max to guarantee it is >= base heuristic, and safe_y to keep projection valid + return max(lifted_val, safe_val), safe_y + + return self.evaluate_named_heuristic(support_data, h_name) + + def project_child_y( + self, + parent_state: SearchState, + child_dets: np.ndarray, + child_errs: np.ndarray, + child_blocked_errs: np.ndarray, + child_det_counts: np.ndarray, + flipped_detectors: Sequence[int], + ) -> Tuple[float, Optional[np.ndarray], str]: + if parent_state.y_prices is None: + raise AssertionError("Expected a stored feasible y vector before projecting to a child.") + + self.heuristic_calls += 1 + self.projection_heuristic_calls += 1 + available = self._available_errors(child_errs, child_blocked_errs) + if not self._has_cover_for_all_active_detectors(child_dets, available): + return INF, None, "projected" + + y_projected = np.zeros(self.num_detectors, dtype=np.float64) + keep = parent_state.dets & child_dets + y_projected[keep] = parent_state.y_prices[keep] + projected_value = float(y_projected[np.flatnonzero(child_dets)].sum()) + best_value = projected_value + best_y = y_projected + best_source = "projected" + + if self.projection_combine_max_plain: + plain_value, plain_y = self.plain_detcost_from_counts(child_dets, available, child_det_counts) + if plain_y is None: + return INF, None, "plain" + if plain_value > best_value + HEURISTIC_EPS: + best_value = plain_value + best_y = plain_y + best_source = "plain" + + return best_value, best_y, best_source + + def report_root_heuristics(self, dets: np.ndarray, errs: np.ndarray, blocked_errs: np.ndarray) -> List[Tuple[str, float]]: + available = self._available_errors(errs, blocked_errs) + support_data = self.build_support_data(dets, available) + names = ["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp", "lifted_sweep", "lifted_exact_lp"] + out: List[Tuple[str, float]] = [] + saved_lp_calls = self.lp_calls + 
+ for name in names: + if name in {"lifted_sweep", "lifted_exact_lp"}: + if name == "lifted_sweep": + safe_val, safe_y = self.heuristic_plain_sweep(support_data) + else: + safe_val, safe_y = self.heuristic_exact_lp(support_data) + if safe_y is None: + out.append((name, INF)) + else: + lifted_val = self._apply_inactive_lift(safe_y.copy(), dets, available) + out.append((name, max(lifted_val, safe_val))) + else: + value, _ = self.evaluate_named_heuristic(support_data, name) + out.append((name, value)) + + self.lp_calls = saved_lp_calls + return out + + def _maybe_refine_node(self, state: SearchState) -> Tuple[SearchState, bool]: + if state.refined or self.heuristic_name == "plain" or not self.lazy_reinsert_heuristics: + return state, False + + previous_source = state.h_source + self.refinement_calls += 1 + new_value, new_y = self.compute_support_based_heuristic( + state.dets, + state.errs, + state.blocked_errs, + name=self.heuristic_name, + ) + if new_y is None: + if previous_source == "projected": + self.projected_nodes_refined += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost:.6f} new_h=INF delta=INF reinserted=False discarded=True" + ) + state.refined = True + return state, True + + delta = new_value - state.h_cost + self.total_refinement_gain += max(0.0, delta) + self.max_refinement_gain = max(self.max_refinement_gain, max(0.0, delta)) + + if self.heuristic_name in {"exact_lp", "lifted_exact_lp"} and new_value + 1e-7 < state.h_cost: + raise AssertionError( + f"Exact LP value {new_value} below stored projected value {state.h_cost}." 
+ ) + + if new_value > state.h_cost + HEURISTIC_EPS: + if previous_source == "projected": + self.projected_nodes_refined += 1 + state.h_cost = new_value + state.h_source = "refined" + state.y_prices = new_y + state.refined = True + self.reinserts += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost - delta:.6f} new_h={new_value:.6f} delta={delta:.6f} reinserted=True discarded=False" + ) + return state, True + + if previous_source == "projected": + self.projected_nodes_refined += 1 + if abs(new_value - state.h_cost) <= HEURISTIC_EPS: + state.y_prices = new_y + state.refined = True + if self.verbose_search: + new_text = "INF" if math.isinf(new_value) else f"{new_value:.6f}" + print( + f" refine approx_h={state.h_cost:.6f} new_h={new_text} delta={delta:.6f} reinserted=False discarded=False" + ) + return state, False + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + root_h, root_y = self.plain_detcost_from_counts(dets0, self._available_errors(errs0, blocked0), det_counts0) + if root_y is None or math.isinf(root_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + 
projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + root_refined = (self.heuristic_name == "plain") or (not self.lazy_reinsert_heuristics) + if root_refined and self.heuristic_name != "plain": + eager_h, eager_y = self.compute_support_based_heuristic(dets0, errs0, blocked0, name=self.heuristic_name) + if eager_y is None or math.isinf(eager_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + root_h, root_y = eager_h, eager_y + + root_state = SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + h_cost=root_h, + h_source="plain" if not root_refined else ("plain" if self.heuristic_name == "plain" else "refined"), + refined=root_refined, + y_prices=root_y, + ) + + heap: List[Tuple[float, int, int, SearchState]] = [] + counter = 0 + heapq.heappush(heap, (root_state.g_cost + root_state.h_cost, int(dets0.sum()), counter, root_state)) + counter += 1 + nodes_pushed = 1 + nodes_popped = 0 + max_queue_size = 1 + min_num_dets = int(dets0.sum()) + + while heap: + max_queue_size = max(max_queue_size, len(heap)) + 
f_cost, num_dets, _entry_id, state = heapq.heappop(heap) + nodes_popped += 1 + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f"len(heap)={len(heap)} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"lp_calls={self.lp_calls} reinserts={self.reinserts} proj_generated={self.projected_nodes_generated} " + f"proj_refined={self.projected_nodes_refined} proj_unrefined_so_far={projected_unrefined} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f} " + f"h={state.h_cost:.6f} h_source={state.h_source} refined={state.refined}" + ) + + if num_dets == 0: + return DecodeResult( + success=True, + errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + state, should_reinsert = self._maybe_refine_node(state) + if should_reinsert: + if state.y_prices is None or math.isinf(state.h_cost): + if state.h_source == "projected": + self.projected_nodes_refined += 1 + continue + if state.h_source != "plain": + heapq.heappush(heap, (state.g_cost + state.h_cost, num_dets, 
counter, state)) + counter += 1 + continue + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked = state.blocked_errs.copy() + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + children_projected = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked[ei] = True + if state.errs[ei] or state.blocked_errs[ei]: + continue + + child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked = prefix_blocked.copy() + child_dets = state.dets.copy() + child_det_counts = state.det_counts.copy() + for d in self.error_detectors[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_g = state.g_cost + float(self.weights[ei]) + if self.heuristic_name == "plain" or (not self.lazy_reinsert_heuristics): + child_h, child_y = self.compute_support_based_heuristic( + child_dets, child_errs, child_blocked, name=self.heuristic_name + ) + child_source = "plain" if self.heuristic_name == "plain" else "refined" + child_refined = True + else: + if state.y_prices is None: + raise AssertionError("Expected parent feasible y-prices before projecting to child.") + child_h, child_y, child_source = self.project_child_y( + state, + child_dets, + child_errs, + child_blocked, + child_det_counts, + self.error_detectors[ei], + ) + self.projected_nodes_generated += 1 + children_projected += 1 + child_refined = False + + if child_y is None or math.isinf(child_h): + children_infeasible += 1 + continue + + child_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=child_h, + h_source=child_source, + refined=child_refined, + y_prices=child_y, + ) + heapq.heappush(heap, 
(child_g + child_h, child_num_dets, counter, child_state)) + counter += 1 + nodes_pushed += 1 + children_generated += 1 + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f" expanded children_generated={children_generated} children_projected={children_projected} " + f"beam_pruned={children_beam_pruned} infeasible={children_infeasible} " + f"lp_calls={self.lp_calls} proj_unrefined_so_far={projected_unrefined}" + ) + + return DecodeResult( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.weights[errs].sum()) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.error_detectors[int(ei)]: + dets[d] ^= True + return dets + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.error_observables[int(ei)]: + parity[int(obs)] = not parity.get(int(obs), False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), 
dtype=np.int32) + + +def sample_detections_and_observables( + circuit: stim.Circuit, + *, + num_shots: int, + seed: int, + num_detectors: int, + num_observables: int, +) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics." 
+ ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--dets", + type=str, + default=None, + help="String of shot dets (e.g., 'shot D0 D1 L2') to parse instead of sampling.", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample from Stim before selecting --shot (default: 100).", + ) + parser.add_argument( + "--shot", + type=int, + default=0, + help="Index of the sampled shot to decode (default: 0).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the residual detector count; use 'inf' for none.", + ) + parser.add_argument( + "--heuristic", + choices=["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp", "lifted_sweep", "lifted_exact_lp"], + default="best_of_two", + help="Lower-bound heuristic to use during A* search (default: best_of_two).", + ) + parser.add_argument( + "--lazy-reinsert-heuristics", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For non-plain heuristics, seed nodes with plain detcost, refine on pop, and reinsert when the selected " + "heuristic improves the key (default: enabled)." 
+ ), + ) + parser.add_argument( + "--projection-combine-max-plain", + action=argparse.BooleanOptionalAction, + default=True, + help="When projecting parent y-prices to a child, take max(projected, plain detcost) (default: enabled).", + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action=argparse.BooleanOptionalAction, + default=True, + help="Exclude precedence-blocked errors when forming the lower bound (default: enabled).", + ) + parser.add_argument( + "--report-all-root-heuristics", + action="store_true", + help="Print all root-node heuristic values, including exact_lp, for the selected shot.", + ) + parser.add_argument( + "--skip-decode", + action="store_true", + help="Only report root heuristics; do not run A* search.", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the selected shot's active detector IDs (default: enabled).", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the decoded merged-error indices when decoding succeeds (default: enabled).", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = 
merged_errors_from_dem(dem) if args.merge_errors else list(iter_dem_errors_from_dem(dem)) + + if args.dets is not None: + shot_dets = np.zeros(dem.num_detectors, dtype=bool) + shot_obs = np.zeros(dem.num_observables, dtype=bool) + for token in args.dets.split(): + if token == "shot": + continue + if token.startswith("D") and token[1:].isdigit(): + d_idx = int(token[1:]) + if d_idx < dem.num_detectors: + shot_dets[d_idx] = True + elif token.startswith("L") and token[1:].isdigit(): + l_idx = int(token[1:]) + if l_idx < dem.num_observables: + shot_obs[l_idx] = True + else: + dets, obs = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + ) + shot_dets = dets[args.shot] + shot_obs = obs[args.shot] + + decoder = GreedySingletonHeuristicDecoder( + errors, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + heuristic=args.heuristic, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + lazy_reinsert_heuristics=args.lazy_reinsert_heuristics, + projection_combine_max_plain=args.projection_combine_max_plain, + verbose_search=args.verbose_search, + ) + + print(f"circuit = {args.circuit}") + print(f"heuristic = {args.heuristic}") + print(f"mode = {decoder.mode_name}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"shot = {args.shot}") + print(f"num_errors = {decoder.num_errors}") + print(f"num_detectors = {decoder.num_detectors}") + print(f"num_observables = {decoder.num_observables}") + print(f"det_beam = {args.det_beam}") + print(f"merge_errors = {args.merge_errors}") + print(f"respect_blocked_errors_in_heuristic = {args.respect_blocked_errors_in_heuristic}") + print(f"lazy_reinsert_heuristics = {args.lazy_reinsert_heuristics}") + print(f"projection_combine_max_plain = {args.projection_combine_max_plain}") + + if args.show_shot_detectors: + active_dets = np.flatnonzero(shot_dets) + 
print("shot_detectors =", " ".join(f"D{d}" for d in active_dets)) + + if args.report_all_root_heuristics: + root_errs = np.zeros(decoder.num_errors, dtype=bool) + root_blocked = np.zeros(decoder.num_errors, dtype=bool) + report = decoder.report_root_heuristics(shot_dets, root_errs, root_blocked) + exact = next((v for k, v in report if k == "exact_lp"), None) + print("root_heuristics:") + for name, value in report: + if exact is not None and not math.isinf(exact) and exact > 0: + ratio = value / exact if not math.isinf(value) else INF + ratio_text = "INF" if math.isinf(ratio) else f"{ratio:.6f}" + else: + ratio_text = "n/a" + value_text = "INF" if math.isinf(value) else f"{value:.12f}" + print(f" {name:>12s} value={value_text} ratio_to_exact={ratio_text}") + + if args.skip_decode: + return 0 + + result = decoder.decode(shot_dets, det_beam=args.det_beam) + print(f"success = {result.success}") + print(f"nodes_pushed = {result.nodes_pushed}") + print(f"nodes_popped = {result.nodes_popped}") + print(f"max_queue_size = {result.max_queue_size}") + print(f"heuristic_calls = {result.heuristic_calls}") + print(f"plain_heuristic_calls = {result.plain_heuristic_calls}") + print(f"projection_heuristic_calls = {result.projection_heuristic_calls}") + print(f"refinement_calls = {result.refinement_calls}") + print(f"lp_calls = {result.lp_calls}") + print(f"reinserts = {result.reinserts}") + print(f"projected_nodes_generated = {result.projected_nodes_generated}") + print(f"projected_nodes_refined = {result.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish = {result.projected_nodes_unrefined_at_finish}") + print(f"total_refinement_gain = {result.total_refinement_gain:.6f}") + print(f"max_refinement_gain = {result.max_refinement_gain:.6f}") + print(f"elapsed_seconds = {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + if args.show_error_indices: + print("decoded_error_indices =", " ".join(map(str, 
np.flatnonzero(result.errs).tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + decoded_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + sampled_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors = {int(result.errs.sum())}") + print(f"decoded_cost = {decoded_cost:.12f}") + print("predicted_observables =", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables =", " ".join(f"L{o}" for o in sampled_obs.tolist())) + print(f"observables_match = {bool(np.array_equal(predicted_obs, sampled_obs))}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_singleton_restricted_lazy.py b/src/py/astar/astar_prototype_singleton_restricted_lazy.py new file mode 100644 index 0000000..3545498 --- /dev/null +++ b/src/py/astar/astar_prototype_singleton_restricted_lazy.py @@ -0,0 +1,1675 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with lazy optimal-singleton refinement. + +This script is intentionally packaged similarly to astar_prototype_subset_detcost_lazy.py, +but specialized to the singleton LP. It offers three modes: + + --opt-singleton-detcost-mode plain + Use plain detcost only. + + --opt-singleton-detcost-mode full + Lazy exact singleton LP on pop, with projected child lower bounds. + + --opt-singleton-detcost-mode restricted + Lazy exact singleton LP on pop, but solved by a restricted-master / + separation loop seeded from the parent tight set. + +Two "outside the box" ideas are built in: + + 1) Parent-primal projection. + If y_parent is feasible for the parent singleton LP, then setting the child + detector prices to y_parent on detectors that remain active and 0 on newly + active detectors is automatically feasible for the child singleton LP. 
+ That gives a cheap admissible child lower bound. + + 2) Local residual projection LP. + On top of the projected parent prices, we can re-optimize a tiny local LP + on either the newly active detectors or the neighborhood touched by the + changed detector set, while keeping the outside detector prices fixed. + This is still admissible because it is a feasible child primal solution. +""" + +from __future__ import annotations + +import argparse +import heapq +import json +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy import sparse +from scipy.optimize import linprog + +INF = math.inf +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class MergedError: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] + + +@dataclass +class DecoderData: + num_detectors: int + num_observables: int + errors: List[MergedError] + detector_to_errors: List[List[int]] + error_costs: np.ndarray + error_detectors: List[Tuple[int, ...]] + error_observables: List[Tuple[int, ...]] + + +@dataclass +class SearchState: + activated_errors: Tuple[int, ...] + blocked_errors: np.ndarray + active_detectors: np.ndarray + active_detector_counts: np.ndarray + path_cost: float + heuristic_cost: float + heuristic_source: str + exact_refined: bool + lp_solution: Optional["SingletonLPSolution"] = None + warm_start_solution: Optional["SingletonLPSolution"] = None + changed_detectors_from_parent: Tuple[int, ...] 
= () + + +@dataclass +class DecodeStats: + num_pq_pushed: int + num_nodes_popped: int + max_queue_size: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + exact_refinement_calls: int + lp_calls: int + lp_reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_lp_refinement_gain: float + max_lp_refinement_gain: float + lp_total_seconds: float + projection_local_lp_calls: int + projection_local_lp_seconds: float + restricted_total_rounds: int + restricted_total_added_supports: int + restricted_total_fallbacks: int + full_check_calls: int + full_check_max_abs_delta: float + elapsed_seconds: float + heuristic_name: str + + +@dataclass +class DecodeResult: + activated_errors: Tuple[int, ...] + path_cost: float + stats: DecodeStats + + +@dataclass(frozen=True) +class RestrictedMasterConfig: + add_policy: str = "topk" # one | topk | all + add_top_k: int = 3 + violation_tol: float = 1e-9 + tight_tol: float = 1e-8 + prune_slack: bool = True + prune_tol: float = 1e-8 + seed_normalized_global_k: int = 0 + seed_normalized_touching_changed_k: int = 2 + max_rounds: int = 50 + fallback_full: bool = True + full_check_every: int = 0 + + +@dataclass +class SingletonLPSolution: + value: float + active_detectors: Tuple[int, ...] + y_by_detector: Dict[int, float] + tight_supports: Tuple[Tuple[int, ...], ...] 
+ num_components: int + num_variables: int + num_constraints: int + num_selected_constraints: int + num_rounds: int + solve_mode: str + + +@dataclass +class SingletonLPSolverStats: + lp_calls: int = 0 + lp_total_seconds: float = 0.0 + projection_local_lp_calls: int = 0 + projection_local_lp_seconds: float = 0.0 + restricted_total_rounds: int = 0 + restricted_total_added_supports: int = 0 + restricted_total_fallbacks: int = 0 + full_check_calls: int = 0 + full_check_max_abs_delta: float = 0.0 + + +class UnionFind: + def __init__(self, size: int) -> None: + self.parent = list(range(size)) + self.rank = [0] * size + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + +class LPLogger: + def __init__(self, path: Path, *, every: int = 1, top_k: int = 12) -> None: + self.path = path + self.every = max(1, every) + self.top_k = max(1, top_k) + self.path.parent.mkdir(parents=True, exist_ok=True) + self.path.write_text("") + + def maybe_log(self, *, call_index: int, payload: Dict[str, Any]) -> None: + if call_index % self.every != 0: + return + with self.path.open("a", encoding="utf-8") as f: + f.write(json.dumps(payload, sort_keys=True) + "\n") + + +class SingletonLPHeuristic: + def __init__( + self, + data: DecoderData, + *, + exact_mode: str, + projection_mode: str, + projection_combine_max_plain: bool, + restricted_config: RestrictedMasterConfig, + logger: Optional[LPLogger] = None, + ) -> None: + if exact_mode not in {"full", "restricted"}: + raise ValueError(f"Unsupported exact_mode: {exact_mode}") + if projection_mode not in {"plain", "parent_y", "new_only", "changed_neighborhood"}: + raise 
ValueError(f"Unsupported projection_mode: {projection_mode}") + self.data = data + self.exact_mode = exact_mode + self.projection_mode = projection_mode + self.projection_combine_max_plain = projection_combine_max_plain + self.restricted_config = restricted_config + self.logger = logger + self.stats = SingletonLPSolverStats() + self.exact_solve_calls = 0 + + def reset_stats(self) -> None: + self.stats = SingletonLPSolverStats() + self.exact_solve_calls = 0 + + def solve_exact( + self, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, + *, + warm_start_solution: Optional[SingletonLPSolution], + changed_detectors: Tuple[int, ...], + ) -> Tuple[SingletonLPSolution, Dict[str, Any]]: + self.exact_solve_calls += 1 + t0 = time.perf_counter() + active_detector_ids = tuple(int(d) for d in np.flatnonzero(active_detectors)) + support_costs = build_active_support_costs( + data=self.data, + active_detectors=active_detectors, + blocked_errors=blocked_errors, + active_detector_counts=active_detector_counts, + ) + + if not active_detector_ids: + elapsed = time.perf_counter() - t0 + payload = { + "solve_mode": self.exact_mode, + "objective": 0.0, + "num_components": 0, + "num_variables": 0, + "num_constraints": 0, + "num_selected_constraints": 0, + "num_rounds": 0, + "num_supports_total": 0, + "solve_seconds": elapsed, + "structurally_infeasible": False, + } + return ( + SingletonLPSolution( + value=0.0, + active_detectors=(), + y_by_detector={}, + tight_supports=(), + num_components=0, + num_variables=0, + num_constraints=0, + num_selected_constraints=0, + num_rounds=0, + solve_mode=self.exact_mode, + ), + payload, + ) + + missing_cover = [ + detector + for detector in active_detector_ids + if not any(detector in support for support in support_costs) + ] + if missing_cover: + elapsed = time.perf_counter() - t0 + payload = { + "solve_mode": self.exact_mode, + "objective": INF, + "num_components": 0, + "num_variables": 
len(active_detector_ids), + "num_constraints": len(support_costs), + "num_selected_constraints": 0, + "num_rounds": 0, + "num_supports_total": len(support_costs), + "solve_seconds": elapsed, + "structurally_infeasible": True, + "missing_cover_detectors": missing_cover, + } + return ( + SingletonLPSolution( + value=INF, + active_detectors=active_detector_ids, + y_by_detector={}, + tight_supports=(), + num_components=0, + num_variables=len(active_detector_ids), + num_constraints=len(support_costs), + num_selected_constraints=0, + num_rounds=0, + solve_mode=self.exact_mode, + ), + payload, + ) + + if self.exact_mode == "full": + solution, full_payload = self._solve_full_support_lp( + active_detector_ids=active_detector_ids, + support_costs=support_costs, + solve_mode="full", + ) + elapsed = time.perf_counter() - t0 + payload = dict(full_payload) + payload.update( + { + "solve_mode": "full", + "num_supports_total": len(support_costs), + "solve_seconds": elapsed, + "structurally_infeasible": False, + } + ) + return solution, payload + + solution, payload = self._solve_restricted_exact( + active_detector_ids=active_detector_ids, + support_costs=support_costs, + warm_start_solution=warm_start_solution, + changed_detectors=changed_detectors, + ) + payload.update( + { + "solve_mode": "restricted", + "num_supports_total": len(support_costs), + "structurally_infeasible": False, + } + ) + return solution, payload + + def project_to_child( + self, + parent_solution: SingletonLPSolution, + child_active_detectors: np.ndarray, + child_blocked_errors: np.ndarray, + child_active_detector_counts: np.ndarray, + *, + changed_detectors: Tuple[int, ...], + ) -> float: + if self.projection_mode == "plain": + projected = plain_detcost_heuristic( + data=self.data, + active_detectors=child_active_detectors, + blocked_errors=child_blocked_errors, + active_detector_counts=child_active_detector_counts, + ) + return projected + + parent_active_set = set(parent_solution.active_detectors) + 
child_active_ids = tuple(int(d) for d in np.flatnonzero(child_active_detectors)) + parent_y = parent_solution.y_by_detector + + # Fixed outside prices inherited from the parent exact primal solution. + fixed_outside_y: Dict[int, float] = {} + region_detectors: set[int] = set() + if self.projection_mode == "parent_y": + region_detectors = set() + elif self.projection_mode == "new_only": + region_detectors = {d for d in child_active_ids if d not in parent_active_set} + elif self.projection_mode == "changed_neighborhood": + changed_set = set(changed_detectors) + region_detectors = {d for d in child_active_ids if d not in parent_active_set} + for detector in changed_set: + for error_index in self.data.detector_to_errors[detector]: + if child_blocked_errors[error_index]: + continue + if child_active_detector_counts[error_index] <= 0: + continue + for other_detector in self.data.error_detectors[error_index]: + if child_active_detectors[other_detector]: + region_detectors.add(other_detector) + else: + raise AssertionError("unreachable projection mode") + + for detector in child_active_ids: + if detector in region_detectors: + continue + if detector in parent_active_set: + fixed_outside_y[detector] = parent_y.get(detector, 0.0) + else: + fixed_outside_y[detector] = 0.0 + + projected = sum(fixed_outside_y.values()) + + if region_detectors: + local_gain = self._solve_local_region_projection_lp( + child_active_detectors=child_active_detectors, + child_blocked_errors=child_blocked_errors, + child_active_detector_counts=child_active_detector_counts, + region_detectors=region_detectors, + fixed_outside_y=fixed_outside_y, + ) + if local_gain == INF: + projected = INF + else: + projected += local_gain + + if self.projection_combine_max_plain: + plain = plain_detcost_heuristic( + data=self.data, + active_detectors=child_active_detectors, + blocked_errors=child_blocked_errors, + active_detector_counts=child_active_detector_counts, + ) + projected = max(projected, plain) + return 
projected

+    def _solve_local_region_projection_lp(
+        self,
+        *,
+        child_active_detectors: np.ndarray,
+        child_blocked_errors: np.ndarray,
+        child_active_detector_counts: np.ndarray,
+        region_detectors: set[int],
+        fixed_outside_y: Dict[int, float],
+    ) -> float:
+        if not region_detectors:
+            return 0.0
+        t0 = time.perf_counter()
+        region_support_costs: Dict[Tuple[int, ...], float] = {}
+        for error_index, error_detectors in enumerate(self.data.error_detectors):
+            if child_blocked_errors[error_index]:
+                continue
+            count = int(child_active_detector_counts[error_index])
+            if count <= 0:
+                continue
+            full_support = tuple(d for d in error_detectors if child_active_detectors[d])
+            assert len(full_support) == count
+            local_support = tuple(d for d in full_support if d in region_detectors)
+            if not local_support:
+                continue
+            fixed = sum(fixed_outside_y.get(d, 0.0) for d in full_support if d not in region_detectors)
+            residual = float(self.data.error_costs[error_index]) - fixed
+            if residual < -1e-8:
+                raise AssertionError(
+                    f"Projected parent y is infeasible for child: residual={residual} on error {error_index}."
+                )
+            residual = max(0.0, residual)
+            previous = region_support_costs.get(local_support)
+            if previous is None or residual < previous:
+                region_support_costs[local_support] = residual
+
+        region_detector_ids = tuple(sorted(region_detectors))
+        if any(not any(detector in support for support in region_support_costs) for detector in region_detector_ids):
+            # Some region detector has no covering support: skip the local LP and
+            # return 0.0 gain for the whole region (a safe, admissible lower bound).
+ elapsed = time.perf_counter() - t0 + self.stats.projection_local_lp_calls += 1 + self.stats.projection_local_lp_seconds += elapsed + return 0.0 + + objective, _, _, _, _ = solve_primal_lp_on_supports( + detector_ids=region_detector_ids, + support_costs=region_support_costs, + record_stats=self.stats, + count_as_main_lp=False, + ) + elapsed = time.perf_counter() - t0 + self.stats.projection_local_lp_calls += 1 + self.stats.projection_local_lp_seconds += elapsed + return objective + + def _solve_full_support_lp( + self, + *, + active_detector_ids: Tuple[int, ...], + support_costs: Dict[Tuple[int, ...], float], + solve_mode: str, + ) -> Tuple[SingletonLPSolution, Dict[str, Any]]: + components = split_support_costs_into_components( + active_detector_ids=active_detector_ids, + support_costs=support_costs, + ) + total_value = 0.0 + total_num_variables = 0 + total_num_constraints = 0 + y_by_detector: Dict[int, float] = {} + tight_supports: List[Tuple[int, ...]] = [] + for detector_ids, component_support_costs in components: + value, component_y, component_tight, num_vars, num_constraints = solve_primal_lp_on_supports( + detector_ids=detector_ids, + support_costs=component_support_costs, + record_stats=self.stats, + count_as_main_lp=True, + ) + total_value += value + total_num_variables += num_vars + total_num_constraints += num_constraints + y_by_detector.update(component_y) + tight_supports.extend(component_tight) + + solution = SingletonLPSolution( + value=total_value, + active_detectors=active_detector_ids, + y_by_detector=y_by_detector, + tight_supports=tuple(sorted(set(tight_supports))), + num_components=len(components), + num_variables=total_num_variables, + num_constraints=total_num_constraints, + num_selected_constraints=total_num_constraints, + num_rounds=1, + solve_mode=solve_mode, + ) + payload = { + "objective": total_value, + "num_components": len(components), + "num_variables": total_num_variables, + "num_constraints": total_num_constraints, + 
"num_selected_constraints": total_num_constraints, + "num_rounds": 1, + "tight_support_count": len(solution.tight_supports), + "top_tight_supports": [ + {"support": list(support), "cost": float(support_costs[support])} + for support in sorted(solution.tight_supports, key=lambda s: (len(s), s))[: self.logger.top_k if self.logger else 12] + ], + } + return solution, payload + + def _solve_restricted_exact( + self, + *, + active_detector_ids: Tuple[int, ...], + support_costs: Dict[Tuple[int, ...], float], + warm_start_solution: Optional[SingletonLPSolution], + changed_detectors: Tuple[int, ...], + ) -> Tuple[SingletonLPSolution, Dict[str, Any]]: + t0 = time.perf_counter() + components = split_support_costs_into_components( + active_detector_ids=active_detector_ids, + support_costs=support_costs, + ) + total_value = 0.0 + total_num_variables = 0 + total_num_constraints = 0 + total_num_selected_constraints = 0 + total_rounds = 0 + y_by_detector: Dict[int, float] = {} + tight_supports: List[Tuple[int, ...]] = [] + component_payloads: List[Dict[str, Any]] = [] + fallbacks_used = 0 + + parent_tight_supports = set() if warm_start_solution is None else set(warm_start_solution.tight_supports) + changed_set = set(changed_detectors) + + for detector_ids, component_support_costs in components: + component_result, component_payload = self._solve_restricted_component( + detector_ids=detector_ids, + support_costs=component_support_costs, + parent_tight_supports=parent_tight_supports, + changed_set=changed_set, + ) + total_value += component_result["value"] + total_num_variables += len(detector_ids) + total_num_constraints += len(component_support_costs) + total_num_selected_constraints += component_result["num_selected_constraints"] + total_rounds += component_result["num_rounds"] + y_by_detector.update(component_result["y_by_detector"]) + tight_supports.extend(component_result["tight_supports"]) + component_payloads.append(component_payload) + if 
component_result["used_full_fallback"]: + fallbacks_used += 1 + + self.stats.restricted_total_rounds += total_rounds + self.stats.restricted_total_fallbacks += fallbacks_used + + solution = SingletonLPSolution( + value=total_value, + active_detectors=active_detector_ids, + y_by_detector=y_by_detector, + tight_supports=tuple(sorted(set(tight_supports))), + num_components=len(components), + num_variables=total_num_variables, + num_constraints=total_num_constraints, + num_selected_constraints=total_num_selected_constraints, + num_rounds=total_rounds, + solve_mode="restricted", + ) + + if self.restricted_config.full_check_every > 0 and self.exact_solve_calls % self.restricted_config.full_check_every == 0: + self.stats.full_check_calls += 1 + full_solution, _ = self._solve_full_support_lp( + active_detector_ids=active_detector_ids, + support_costs=support_costs, + solve_mode="full_check", + ) + delta = abs(full_solution.value - solution.value) + self.stats.full_check_max_abs_delta = max(self.stats.full_check_max_abs_delta, delta) + if delta > 1e-7: + raise AssertionError( + f"Restricted exact solver mismatch: restricted={solution.value} full={full_solution.value} delta={delta}" + ) + + payload = { + "objective": total_value, + "num_components": len(components), + "num_variables": total_num_variables, + "num_constraints": total_num_constraints, + "num_selected_constraints": total_num_selected_constraints, + "num_rounds": total_rounds, + "tight_support_count": len(solution.tight_supports), + "used_full_fallbacks": fallbacks_used, + "components": component_payloads, + "solve_seconds": time.perf_counter() - t0, + } + return solution, payload + + def _solve_restricted_component( + self, + *, + detector_ids: Tuple[int, ...], + support_costs: Dict[Tuple[int, ...], float], + parent_tight_supports: set[Tuple[int, ...]], + changed_set: set[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + cover_support_for_detector: Dict[int, Tuple[int, ...]] = {} + supports_touching_changed: 
List[Tuple[int, ...]] = [] + for support, cost in support_costs.items(): + for detector in support: + previous = cover_support_for_detector.get(detector) + if previous is None: + cover_support_for_detector[detector] = support + else: + prev_key = (support_costs[previous], len(previous), previous) + cur_key = (cost, len(support), support) + if cur_key < prev_key: + cover_support_for_detector[detector] = support + if changed_set and any(detector in changed_set for detector in support): + supports_touching_changed.append(support) + + selected_supports: set[Tuple[int, ...]] = set(cover_support_for_detector.values()) + cover_supports = set(selected_supports) + surviving_parent_tight = {support for support in parent_tight_supports if support in support_costs} + selected_supports |= surviving_parent_tight + + if self.restricted_config.seed_normalized_global_k > 0: + cheapest_norm = sorted( + support_costs, + key=lambda support: (support_costs[support] / len(support), support_costs[support], len(support), support), + )[: self.restricted_config.seed_normalized_global_k] + selected_supports.update(cheapest_norm) + + if self.restricted_config.seed_normalized_touching_changed_k > 0 and supports_touching_changed: + touching = sorted( + supports_touching_changed, + key=lambda support: (support_costs[support] / len(support), support_costs[support], len(support), support), + )[: self.restricted_config.seed_normalized_touching_changed_k] + selected_supports.update(touching) + + rounds = 0 + total_added_supports = 0 + used_full_fallback = False + payload_rounds: List[Dict[str, Any]] = [] + + while True: + rounds += 1 + selected_supports |= cover_supports + restricted_support_costs = {support: support_costs[support] for support in selected_supports} + value, y_by_detector, selected_tight_supports, num_vars, num_selected_constraints = solve_primal_lp_on_supports( + detector_ids=detector_ids, + support_costs=restricted_support_costs, + record_stats=self.stats, + count_as_main_lp=True, 
+ ) + slacks: Dict[Tuple[int, ...], float] = {} + violations: List[Tuple[float, Tuple[int, ...]]] = [] + full_tight_supports: List[Tuple[int, ...]] = [] + for support, cost in support_costs.items(): + lhs = sum(y_by_detector.get(detector, 0.0) for detector in support) + slack = cost - lhs + slacks[support] = slack + if slack < -self.restricted_config.violation_tol: + violations.append((-slack, support)) + if abs(slack) <= self.restricted_config.tight_tol: + full_tight_supports.append(support) + + payload_rounds.append( + { + "round": rounds, + "selected_constraints": len(selected_supports), + "restricted_tight_count": len(selected_tight_supports), + "full_tight_count": len(full_tight_supports), + "max_violation": 0.0 if not violations else float(max(v for v, _ in violations)), + } + ) + + if not violations: + self.stats.restricted_total_added_supports += total_added_supports + component_result = { + "value": value, + "y_by_detector": y_by_detector, + "tight_supports": tuple(sorted(full_tight_supports)), + "num_selected_constraints": len(selected_supports), + "num_rounds": rounds, + "used_full_fallback": False, + } + component_payload = { + "detectors": list(detector_ids), + "supports_total": len(support_costs), + "initial_seed_count": len(cover_supports | surviving_parent_tight), + "final_selected_constraints": len(selected_supports), + "rounds": rounds, + "used_full_fallback": False, + "parent_tight_survivors": len(surviving_parent_tight), + "cover_supports": len(cover_supports), + "round_summaries": payload_rounds, + } + return component_result, component_payload + + if rounds >= self.restricted_config.max_rounds: + if not self.restricted_config.fallback_full: + raise RuntimeError( + f"Restricted singleton LP exceeded max rounds={self.restricted_config.max_rounds} without fallback." 
+ ) + used_full_fallback = True + self.stats.restricted_total_added_supports += total_added_supports + full_value, full_y, full_tight, _, _ = solve_primal_lp_on_supports( + detector_ids=detector_ids, + support_costs=support_costs, + record_stats=self.stats, + count_as_main_lp=True, + ) + component_result = { + "value": full_value, + "y_by_detector": full_y, + "tight_supports": tuple(sorted(full_tight)), + "num_selected_constraints": len(support_costs), + "num_rounds": rounds, + "used_full_fallback": True, + } + component_payload = { + "detectors": list(detector_ids), + "supports_total": len(support_costs), + "initial_seed_count": len(cover_supports | surviving_parent_tight), + "final_selected_constraints": len(support_costs), + "rounds": rounds, + "used_full_fallback": True, + "parent_tight_survivors": len(surviving_parent_tight), + "cover_supports": len(cover_supports), + "round_summaries": payload_rounds, + } + return component_result, component_payload + + if self.restricted_config.prune_slack: + selected_supports = { + support + for support in selected_supports + if slacks.get(support, INF) <= self.restricted_config.prune_tol or support in cover_supports + } + + violations.sort(key=lambda item: (-item[0], support_costs[item[1]], len(item[1]), item[1])) + if self.restricted_config.add_policy == "one": + to_add = [violations[0][1]] + elif self.restricted_config.add_policy == "topk": + to_add = [support for _, support in violations[: self.restricted_config.add_top_k]] + elif self.restricted_config.add_policy == "all": + to_add = [support for _, support in violations] + else: + raise ValueError(f"Unsupported add policy: {self.restricted_config.add_policy}") + new_supports = [support for support in to_add if support not in selected_supports] + total_added_supports += len(new_supports) + selected_supports.update(new_supports) + + +def xor_probability(p0: float, p1: float) -> float: + return p0 * (1 - p1) + (1 - p0) * p1 + + +def parse_beam(text: str) -> float: + 
lowered = text.strip().lower() + if lowered in {"inf", "+inf", "infinity", "+infinity"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("beam must be non-negative or 'inf'") + return float(value) + + +def parse_optional_int(text: str) -> Optional[int]: + lowered = text.strip().lower() + if lowered in {"none", "inf", "infinity"}: + return None + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("value must be non-negative or 'none'") + return value + + +def format_indices(indices: Iterable[int], prefix: str) -> str: + items = list(indices) + if not items: + return "(none)" + return " ".join(f"{prefix}{i}" for i in items) + + +def iter_dem_errors(dem: stim.DetectorErrorModel) -> Iterable[MergedError]: + for instruction in dem.flattened(): + if instruction.type != "error": + continue + probability = float(instruction.args_copy()[0]) + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError("This prototype assumes DEM probabilities are in (0, 0.5).") + detectors: set[int] = set() + observables: set[int] = set() + for target in instruction.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected DEM target: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + yield MergedError( + probability=probability, + likelihood_cost=float(-math.log(probability / (1 - probability))), + detectors=tuple(sorted(detectors)), + observables=tuple(sorted(observables)), + ) + + +def merged_errors(dem: stim.DetectorErrorModel) -> List[MergedError]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + for error in iter_dem_errors(dem): + key = (error.detectors, error.observables) + previous = 
errors_by_symptom.get(key) + if previous is None: + errors_by_symptom[key] = error.probability + else: + errors_by_symptom[key] = xor_probability(previous, error.probability) + + merged: List[MergedError] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError("Merged error has probability >= 0.5, giving a non-positive cost.") + merged.append( + MergedError( + probability=probability, + likelihood_cost=float(-math.log(probability / (1 - probability))), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +def build_decoder_data( + dem: stim.DetectorErrorModel, + *, + merge_errors_in_dem: bool = True, +) -> DecoderData: + errors = merged_errors(dem) if merge_errors_in_dem else list(iter_dem_errors(dem)) + detector_to_errors: List[List[int]] = [[] for _ in range(dem.num_detectors)] + for error_index, error in enumerate(errors): + for detector in error.detectors: + detector_to_errors[detector].append(error_index) + return DecoderData( + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + errors=errors, + detector_to_errors=detector_to_errors, + error_costs=np.asarray([e.likelihood_cost for e in errors], dtype=np.float64), + error_detectors=[e.detectors for e in errors], + error_observables=[e.observables for e in errors], + ) + + +def unpack_bit_packed_rows(bits: np.ndarray, count: int) -> np.ndarray: + return np.unpackbits(bits, bitorder="little", axis=1, count=count).astype(bool, copy=False) + + +def initial_detector_counts(data: DecoderData, active_detectors: np.ndarray) -> np.ndarray: + counts = np.zeros(len(data.errors), dtype=np.int32) + for detector in np.flatnonzero(active_detectors): + for error_index in data.detector_to_errors[int(detector)]: + counts[error_index] += 1 + return counts + + +def apply_error( + data: DecoderData, + active_detectors: np.ndarray, + active_detector_counts: np.ndarray, + error_index: 
int, +) -> Tuple[np.ndarray, np.ndarray]: + next_detectors = active_detectors.copy() + next_counts = active_detector_counts.copy() + for detector in data.error_detectors[error_index]: + if next_detectors[detector]: + next_detectors[detector] = False + delta = -1 + else: + next_detectors[detector] = True + delta = 1 + for other_error_index in data.detector_to_errors[detector]: + next_counts[other_error_index] += delta + return next_detectors, next_counts + + +def plain_detcost_for_detector( + data: DecoderData, + detector: int, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, +) -> float: + best = INF + for error_index in data.detector_to_errors[detector]: + if blocked_errors[error_index]: + continue + count = int(active_detector_counts[error_index]) + assert count > 0 + candidate = float(data.error_costs[error_index]) / count + if candidate < best: + best = candidate + return best + + +def plain_detcost_heuristic( + data: DecoderData, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, +) -> float: + total = 0.0 + for detector in np.flatnonzero(active_detectors): + det_cost = plain_detcost_for_detector( + data=data, + detector=int(detector), + blocked_errors=blocked_errors, + active_detector_counts=active_detector_counts, + ) + if det_cost == INF: + return INF + total += det_cost + return total + + +def build_active_support_costs( + data: DecoderData, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, +) -> Dict[Tuple[int, ...], float]: + support_costs: Dict[Tuple[int, ...], float] = {} + for error_index, error_detectors in enumerate(data.error_detectors): + if blocked_errors[error_index]: + continue + count = int(active_detector_counts[error_index]) + if count <= 0: + continue + support = tuple(detector for detector in error_detectors if active_detectors[detector]) + assert len(support) == count + cost = float(data.error_costs[error_index]) + 
previous = support_costs.get(support) + if previous is None or cost < previous: + support_costs[support] = cost + return support_costs + + +def split_support_costs_into_components( + *, + active_detector_ids: Tuple[int, ...], + support_costs: Dict[Tuple[int, ...], float], +) -> List[Tuple[Tuple[int, ...], Dict[Tuple[int, ...], float]]]: + detector_to_local = {detector: i for i, detector in enumerate(active_detector_ids)} + uf = UnionFind(len(active_detector_ids)) + for support in support_costs: + if len(support) <= 1: + continue + first = detector_to_local[support[0]] + for detector in support[1:]: + uf.union(first, detector_to_local[detector]) + + detectors_by_root: Dict[int, List[int]] = defaultdict(list) + for detector in active_detector_ids: + detectors_by_root[uf.find(detector_to_local[detector])].append(detector) + supports_by_root: Dict[int, Dict[Tuple[int, ...], float]] = defaultdict(dict) + for support, cost in support_costs.items(): + root = uf.find(detector_to_local[support[0]]) + supports_by_root[root][support] = cost + components: List[Tuple[Tuple[int, ...], Dict[Tuple[int, ...], float]]] = [] + for root, detectors in detectors_by_root.items(): + components.append((tuple(sorted(detectors)), supports_by_root[root])) + components.sort(key=lambda item: (len(item[0]), item[0])) + return components + + +def solve_primal_lp_on_supports( + *, + detector_ids: Tuple[int, ...], + support_costs: Dict[Tuple[int, ...], float], + record_stats: SingletonLPSolverStats, + count_as_main_lp: bool, +) -> Tuple[float, Dict[int, float], List[Tuple[int, ...]], int, int]: + detector_to_var = {detector: i for i, detector in enumerate(detector_ids)} + if any(not any(detector in support for support in support_costs) for detector in detector_ids): + raise RuntimeError("LP component has an uncovered detector; restricted master lost coverage.") + + row_indices: List[int] = [] + col_indices: List[int] = [] + values: List[float] = [] + rhs = np.empty(len(support_costs), 
dtype=np.float64) + supports = sorted(support_costs, key=lambda s: (len(s), s)) + for row, support in enumerate(supports): + rhs[row] = float(support_costs[support]) + for detector in support: + row_indices.append(row) + col_indices.append(detector_to_var[detector]) + values.append(1.0) + + a_ub = sparse.csr_matrix( + (values, (row_indices, col_indices)), + shape=(len(supports), len(detector_ids)), + dtype=np.float64, + ) + record_stats.lp_calls += 1 if count_as_main_lp else 0 + t0 = time.perf_counter() + result = linprog( + c=-np.ones(len(detector_ids), dtype=np.float64), + A_ub=a_ub, + b_ub=rhs, + bounds=[(0.0, None)] * len(detector_ids), + method="highs", + ) + elapsed = time.perf_counter() - t0 + if count_as_main_lp: + record_stats.lp_total_seconds += elapsed + if not result.success: + raise RuntimeError( + f"singleton LP solve failed: status={result.status} message={result.message}" + ) + + solution = np.asarray(result.x, dtype=np.float64) + y_by_detector = { + detector_ids[var_index]: float(solution[var_index]) + for var_index in range(len(detector_ids)) + if solution[var_index] > 1e-12 + } + tight_supports: List[Tuple[int, ...]] = [] + for row, support in enumerate(supports): + lhs = float(sum(solution[detector_to_var[detector]] for detector in support)) + if abs(float(rhs[row]) - lhs) <= 1e-8: + tight_supports.append(support) + return float(-result.fun), y_by_detector, tight_supports, len(detector_ids), len(supports) + + +def detectors_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + detectors = np.zeros(data.num_detectors, dtype=bool) + for error_index in activated_errors: + for detector in data.error_detectors[error_index]: + detectors[detector] ^= True + return detectors + + +def observables_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + observables = np.zeros(data.num_observables, dtype=bool) + for error_index in activated_errors: + for observable in 
data.error_observables[error_index]: + observables[observable] ^= True + return observables + + +def decode( + data: DecoderData, + detections: np.ndarray, + *, + det_beam: float = INF, + singleton_solver: Optional[SingletonLPHeuristic] = None, + verbose_search: bool = False, +) -> DecodeResult: + start_time = time.perf_counter() + if singleton_solver is not None: + singleton_solver.reset_stats() + + heuristic_calls = 0 + plain_heuristic_calls = 0 + projection_heuristic_calls = 0 + exact_refinement_calls = 0 + lp_reinserts = 0 + projected_nodes_generated = 0 + projected_nodes_refined = 0 + total_lp_refinement_gain = 0.0 + max_lp_refinement_gain = 0.0 + + initial_active_detectors = np.asarray(detections, dtype=bool).copy() + initial_counts = initial_detector_counts(data, initial_active_detectors) + initial_blocked = np.zeros(len(data.errors), dtype=bool) + heuristic_calls += 1 + plain_heuristic_calls += 1 + initial_heuristic = plain_detcost_heuristic( + data=data, + active_detectors=initial_active_detectors, + blocked_errors=initial_blocked, + active_detector_counts=initial_counts, + ) + if initial_heuristic == INF: + raise RuntimeError("Initial residual syndrome is infeasible under the current pruning rule.") + + initial_state = SearchState( + activated_errors=(), + blocked_errors=initial_blocked, + active_detectors=initial_active_detectors, + active_detector_counts=initial_counts, + path_cost=0.0, + heuristic_cost=initial_heuristic, + heuristic_source="plain", + exact_refined=(singleton_solver is None), + lp_solution=None, + warm_start_solution=None, + changed_detectors_from_parent=(), + ) + + priority_queue: List[Tuple[float, int, int, SearchState]] = [] + push_counter = 0 + initial_num_dets = int(initial_active_detectors.sum()) + heapq.heappush( + priority_queue, + (initial_state.path_cost + initial_state.heuristic_cost, initial_num_dets, push_counter, initial_state), + ) + push_counter += 1 + + num_pq_pushed = 1 + num_nodes_popped = 0 + max_queue_size = 1 + 
min_num_dets = initial_num_dets + max_num_dets = INF if det_beam == INF else min_num_dets + det_beam + + if singleton_solver is None: + heuristic_name = "plain_detcost" + else: + heuristic_name = f"opt_singleton_{singleton_solver.exact_mode}_lazy_{singleton_solver.projection_mode}" + if singleton_solver.projection_combine_max_plain: + heuristic_name += "_maxplain" + + while priority_queue: + max_queue_size = max(max_queue_size, len(priority_queue)) + f_cost, num_dets, _, state = heapq.heappop(priority_queue) + num_nodes_popped += 1 + + if num_dets > max_num_dets: + continue + + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = INF if det_beam == INF else min_num_dets + det_beam + + if verbose_search: + print( + f"nodes_popped={num_nodes_popped} len(pq)={len(priority_queue)} " + f"lp_calls={0 if singleton_solver is None else singleton_solver.stats.lp_calls} " + f"lp_reinserts={lp_reinserts} proj_generated={projected_nodes_generated} " + f"proj_refined={projected_nodes_refined} " + f"proj_unrefined_so_far={projected_nodes_generated - projected_nodes_refined} " + f"active_dets={num_dets} beam_max={max_num_dets} depth={len(state.activated_errors)} " + f"f={f_cost:.12g} g={state.path_cost:.12g} h={state.heuristic_cost:.12g} " + f"h_source={state.heuristic_source} exact_refined={state.exact_refined}" + ) + + if num_dets == 0: + elapsed_seconds = time.perf_counter() - start_time + stats = DecodeStats( + num_pq_pushed=num_pq_pushed, + num_nodes_popped=num_nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=heuristic_calls, + plain_heuristic_calls=plain_heuristic_calls, + projection_heuristic_calls=projection_heuristic_calls, + exact_refinement_calls=exact_refinement_calls, + lp_calls=0 if singleton_solver is None else singleton_solver.stats.lp_calls, + lp_reinserts=lp_reinserts, + projected_nodes_generated=projected_nodes_generated, + projected_nodes_refined=projected_nodes_refined, + 
projected_nodes_unrefined_at_finish=projected_nodes_generated - projected_nodes_refined, + total_lp_refinement_gain=total_lp_refinement_gain, + max_lp_refinement_gain=max_lp_refinement_gain, + lp_total_seconds=0.0 if singleton_solver is None else singleton_solver.stats.lp_total_seconds, + projection_local_lp_calls=0 if singleton_solver is None else singleton_solver.stats.projection_local_lp_calls, + projection_local_lp_seconds=0.0 if singleton_solver is None else singleton_solver.stats.projection_local_lp_seconds, + restricted_total_rounds=0 if singleton_solver is None else singleton_solver.stats.restricted_total_rounds, + restricted_total_added_supports=0 if singleton_solver is None else singleton_solver.stats.restricted_total_added_supports, + restricted_total_fallbacks=0 if singleton_solver is None else singleton_solver.stats.restricted_total_fallbacks, + full_check_calls=0 if singleton_solver is None else singleton_solver.stats.full_check_calls, + full_check_max_abs_delta=0.0 if singleton_solver is None else singleton_solver.stats.full_check_max_abs_delta, + elapsed_seconds=elapsed_seconds, + heuristic_name=heuristic_name, + ) + return DecodeResult( + activated_errors=state.activated_errors, + path_cost=state.path_cost, + stats=stats, + ) + + if singleton_solver is not None and not state.exact_refined: + heuristic_calls += 1 + exact_refinement_calls += 1 + previous_h = state.heuristic_cost + previous_source = state.heuristic_source + exact_solution, exact_payload = singleton_solver.solve_exact( + active_detectors=state.active_detectors, + blocked_errors=state.blocked_errors, + active_detector_counts=state.active_detector_counts, + warm_start_solution=state.warm_start_solution, + changed_detectors=state.changed_detectors_from_parent, + ) + exact_h = exact_solution.value + reinserted = False + discarded = False + + if exact_h == INF: + discarded = True + if previous_source == "projected": + projected_nodes_refined += 1 + else: + if exact_h + 1e-7 < previous_h: + 
raise AssertionError( + f"Exact singleton LP lower bound {exact_h} is below stored {previous_source} lower bound {previous_h}." + ) + delta = exact_h - previous_h + total_lp_refinement_gain += delta + max_lp_refinement_gain = max(max_lp_refinement_gain, delta) + state.heuristic_cost = exact_h + state.heuristic_source = "exact" + state.exact_refined = True + state.lp_solution = exact_solution + if previous_source == "projected": + projected_nodes_refined += 1 + if delta > HEURISTIC_EPS: + reinserted = True + lp_reinserts += 1 + heapq.heappush( + priority_queue, + (state.path_cost + state.heuristic_cost, num_dets, push_counter, state), + ) + push_counter += 1 + + if singleton_solver.logger is not None: + payload = dict(exact_payload) + payload.update( + { + "call_index": exact_refinement_calls, + "phase": "exact_refinement", + "depth": len(state.activated_errors), + "nodes_popped": num_nodes_popped, + "path_cost": state.path_cost, + "active_detector_count": num_dets, + "approx_h": previous_h, + "exact_h": exact_h, + "delta": INF if exact_h == INF else exact_h - previous_h, + "heuristic_source_before": previous_source, + "reinserted": reinserted, + "discarded": discarded, + } + ) + singleton_solver.logger.maybe_log(call_index=exact_refinement_calls, payload=payload) + + if verbose_search: + delta_text = "INF" if exact_h == INF else f"{exact_h - previous_h:.12g}" + exact_text = "INF" if exact_h == INF else f"{exact_h:.12g}" + print( + f" lp_refine approx_h={previous_h:.12g} exact_h={exact_text} delta={delta_text} " + f"vars={exact_solution.num_variables} constraints={exact_solution.num_constraints} " + f"selected={exact_solution.num_selected_constraints} rounds={exact_solution.num_rounds} " + f"tight={len(exact_solution.tight_supports)} reinserted={reinserted} discarded={discarded}" + ) + + if discarded or reinserted: + continue + + min_detector = int(np.flatnonzero(state.active_detectors)[0]) + blocked_prefix = state.blocked_errors.copy() + children_generated = 0 + 
children_projected = 0 + children_beam_pruned = 0 + children_infeasible = 0 + + for error_index in data.detector_to_errors[min_detector]: + blocked_prefix[error_index] = True + if state.blocked_errors[error_index]: + continue + + child_active_detectors, child_active_counts = apply_error( + data=data, + active_detectors=state.active_detectors, + active_detector_counts=state.active_detector_counts, + error_index=error_index, + ) + child_num_dets = int(child_active_detectors.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_blocked = blocked_prefix.copy() + child_path_cost = state.path_cost + float(data.error_costs[error_index]) + changed_detectors = tuple(sorted(data.error_detectors[error_index])) + + if singleton_solver is None: + heuristic_calls += 1 + plain_heuristic_calls += 1 + child_heuristic = plain_detcost_heuristic( + data=data, + active_detectors=child_active_detectors, + blocked_errors=child_blocked, + active_detector_counts=child_active_counts, + ) + child_source = "plain" + child_exact_refined = True + child_lp_solution = None + child_warm_start_solution = None + else: + if state.lp_solution is None: + raise AssertionError("Projected singleton heuristic requires an exact-refined parent solution.") + heuristic_calls += 1 + projection_heuristic_calls += 1 + projected_nodes_generated += 1 + children_projected += 1 + child_heuristic = singleton_solver.project_to_child( + parent_solution=state.lp_solution, + child_active_detectors=child_active_detectors, + child_blocked_errors=child_blocked, + child_active_detector_counts=child_active_counts, + changed_detectors=changed_detectors, + ) + child_source = "projected" + child_exact_refined = False + child_lp_solution = None + child_warm_start_solution = state.lp_solution + + if child_heuristic == INF: + children_infeasible += 1 + continue + + child_state = SearchState( + activated_errors=state.activated_errors + (error_index,), + blocked_errors=child_blocked, + 
active_detectors=child_active_detectors, + active_detector_counts=child_active_counts, + path_cost=child_path_cost, + heuristic_cost=child_heuristic, + heuristic_source=child_source, + exact_refined=child_exact_refined, + lp_solution=child_lp_solution, + warm_start_solution=child_warm_start_solution, + changed_detectors_from_parent=changed_detectors, + ) + heapq.heappush( + priority_queue, + (child_path_cost + child_heuristic, child_num_dets, push_counter, child_state), + ) + push_counter += 1 + num_pq_pushed += 1 + children_generated += 1 + + if verbose_search: + print( + f" expanded children_generated={children_generated} children_projected={children_projected} " + f"beam_pruned={children_beam_pruned} infeasible={children_infeasible}" + ) + + raise RuntimeError("Decoding failed to find any completion.") + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim detector error models. " + "Supports plain detcost, lazy full singleton LP, and a restricted-master singleton LP." + ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a Stim circuit file.") + parser.add_argument("--shot", type=int, default=0, help="Zero-based sampled shot index to decode.") + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample before selecting --shot.", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Seed passed to stim.compile_detector_sampler(...).sample(...).", + ) + parser.add_argument( + "--det-beam", + type=parse_beam, + default=INF, + help="Beam cutoff on residual detector count. 
Use an integer or 'inf'.", + ) + parser.add_argument( + "--opt-singleton-detcost-mode", + choices=["plain", "full", "restricted"], + default="plain", + help="Heuristic mode: plain detcost, lazy full singleton LP, or lazy restricted singleton LP.", + ) + parser.add_argument( + "--projection-mode", + choices=["plain", "parent_y", "new_only", "changed_neighborhood"], + default="changed_neighborhood", + help=( + "How to score child nodes before exact refinement. " + "'parent_y' reuses parent primal detector prices, 'new_only' solves a tiny residual LP on newly active detectors, " + "and 'changed_neighborhood' solves a tiny residual LP on a local region around the changed detectors." + ), + ) + parser.add_argument( + "--projection-combine-max-plain", + action=argparse.BooleanOptionalAction, + default=True, + help="Take max(projected child lower bound, plain detcost).", + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the sampled shot's active detector IDs before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the activated error indices in the final decoding.", + ) + parser.add_argument("--verbose-search", action="store_true", help="Print per-node search diagnostics.") + parser.add_argument( + "--lp-log-path", + type=Path, + default=None, + help="Optional JSONL file for logging exact singleton-LP refinements.", + ) + parser.add_argument( + "--lp-log-top-k", + type=int, + default=12, + help="When logging exact LP refinements, include at most this many top supports.", + ) + parser.add_argument( + "--lp-log-every", + type=int, + default=1, + help="When logging exact LP refinements, only write every k-th refinement.", + ) + 
parser.add_argument( + "--restricted-add-policy", + choices=["one", "topk", "all"], + default="topk", + help="Violation separation policy for restricted singleton LP mode.", + ) + parser.add_argument( + "--restricted-add-top-k", + type=int, + default=3, + help="When --restricted-add-policy=topk, add this many most violated supports.", + ) + parser.add_argument( + "--restricted-max-rounds", + type=int, + default=50, + help="Maximum separation rounds before optional fallback to the full singleton LP.", + ) + parser.add_argument( + "--restricted-fallback-full", + action=argparse.BooleanOptionalAction, + default=True, + help="If restricted mode hits the round limit, fall back to the full singleton LP.", + ) + parser.add_argument( + "--restricted-prune-slack", + action=argparse.BooleanOptionalAction, + default=True, + help="Prune slack supports from the restricted master between rounds.", + ) + parser.add_argument( + "--restricted-prune-tol", + type=float, + default=1e-8, + help="Keep selected supports whose slack is at most this value.", + ) + parser.add_argument( + "--restricted-violation-tol", + type=float, + default=1e-9, + help="Violation tolerance used during separation.", + ) + parser.add_argument( + "--restricted-tight-tol", + type=float, + default=1e-8, + help="Tolerance for tagging a support as tight in the exact solution.", + ) + parser.add_argument( + "--restricted-seed-normalized-global-k", + type=int, + default=0, + help="Add this many globally cheapest supports by cost/size to the initial restricted pool.", + ) + parser.add_argument( + "--restricted-seed-normalized-touching-changed-k", + type=int, + default=2, + help="Add this many cheapest cost/size supports touching changed detectors to the initial restricted pool.", + ) + parser.add_argument( + "--full-check-every", + type=int, + default=0, + help="In restricted mode, solve the full singleton LP every k exact refinements and assert equality (0 disables).", + ) + return parser + + +def main(argv: 
Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.lp_log_every <= 0: + parser.error("--lp-log-every must be positive.") + if args.lp_log_top_k <= 0: + parser.error("--lp-log-top-k must be positive.") + if args.restricted_add_top_k <= 0: + parser.error("--restricted-add-top-k must be positive.") + if args.restricted_max_rounds <= 0: + parser.error("--restricted-max-rounds must be positive.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + data = build_decoder_data(dem, merge_errors_in_dem=args.merge_errors) + + singleton_solver = None + if args.opt_singleton_detcost_mode != "plain": + logger = None + if args.lp_log_path is not None: + logger = LPLogger( + args.lp_log_path, + every=args.lp_log_every, + top_k=args.lp_log_top_k, + ) + restricted_config = RestrictedMasterConfig( + add_policy=args.restricted_add_policy, + add_top_k=args.restricted_add_top_k, + violation_tol=args.restricted_violation_tol, + tight_tol=args.restricted_tight_tol, + prune_slack=args.restricted_prune_slack, + prune_tol=args.restricted_prune_tol, + seed_normalized_global_k=args.restricted_seed_normalized_global_k, + seed_normalized_touching_changed_k=args.restricted_seed_normalized_touching_changed_k, + max_rounds=args.restricted_max_rounds, + fallback_full=args.restricted_fallback_full, + full_check_every=args.full_check_every, + ) + singleton_solver = SingletonLPHeuristic( + data, + exact_mode=args.opt_singleton_detcost_mode, + projection_mode=args.projection_mode, + projection_combine_max_plain=args.projection_combine_max_plain, + restricted_config=restricted_config, + logger=logger, + ) + + dets_packed, obs_packed = circuit.compile_detector_sampler(seed=args.seed).sample( + shots=args.sample_num_shots, + 
separate_observables=True, + bit_packed=True, + ) + detections = unpack_bit_packed_rows(dets_packed, count=dem.num_detectors) + observables = unpack_bit_packed_rows(obs_packed, count=dem.num_observables) + + if args.shot >= detections.shape[0]: + parser.error(f"--shot={args.shot} is out of range for {detections.shape[0]} sampled shots.") + + shot_detections = detections[args.shot] + shot_observables = observables[args.shot] if observables.size else np.zeros(0, dtype=bool) + + print(f"circuit = {args.circuit}") + if singleton_solver is None: + print("heuristic = plain_detcost") + else: + print( + "heuristic = " + + f"opt_singleton_{args.opt_singleton_detcost_mode}_lazy_{args.projection_mode}" + + ("_maxplain" if args.projection_combine_max_plain else "") + ) + print(f"shot = {args.shot}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"num_detectors = {data.num_detectors}") + print(f"num_observables = {data.num_observables}") + print(f"num_errors = {len(data.errors)}") + print(f"beam = {args.det_beam}") + if args.show_shot_detectors: + print(f"shot_detectors = {format_indices(np.flatnonzero(shot_detections), 'D')}") + + result = decode( + data=data, + detections=shot_detections, + det_beam=args.det_beam, + singleton_solver=singleton_solver, + verbose_search=args.verbose_search, + ) + + predicted_observables = observables_from_solution(data, result.activated_errors) + reproduced_detectors = detectors_from_solution(data, result.activated_errors) + if not np.array_equal(reproduced_detectors, shot_detections): + raise AssertionError("Decoded error set does not reproduce the shot's syndrome.") + + print(f"solution_size = {len(result.activated_errors)}") + print(f"solution_cost = {result.path_cost:.12g}") + if args.show_error_indices: + print(f"activated_errors = {format_indices(result.activated_errors, 'E')}") + print(f"predicted_observables = {format_indices(np.flatnonzero(predicted_observables), 'L')}") + print(f"sample_observables = 
{format_indices(np.flatnonzero(shot_observables), 'L')}") + print(f"observables_match = {bool(np.array_equal(predicted_observables, shot_observables))}") + print(f"num_pq_pushed = {result.stats.num_pq_pushed}") + print(f"num_nodes_popped = {result.stats.num_nodes_popped}") + print(f"max_queue_size = {result.stats.max_queue_size}") + print(f"heuristic_calls = {result.stats.heuristic_calls}") + print(f"plain_heuristic_calls = {result.stats.plain_heuristic_calls}") + print(f"projection_heuristic_calls = {result.stats.projection_heuristic_calls}") + print(f"exact_refinement_calls = {result.stats.exact_refinement_calls}") + print(f"lp_calls = {result.stats.lp_calls}") + print(f"lp_reinserts = {result.stats.lp_reinserts}") + print(f"projected_nodes_generated = {result.stats.projected_nodes_generated}") + print(f"projected_nodes_refined = {result.stats.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish = {result.stats.projected_nodes_unrefined_at_finish}") + print(f"total_lp_refinement_gain = {result.stats.total_lp_refinement_gain:.12g}") + print(f"max_lp_refinement_gain = {result.stats.max_lp_refinement_gain:.12g}") + print(f"lp_total_seconds = {result.stats.lp_total_seconds:.6f}") + print(f"projection_local_lp_calls = {result.stats.projection_local_lp_calls}") + print(f"projection_local_lp_seconds = {result.stats.projection_local_lp_seconds:.6f}") + print(f"restricted_total_rounds = {result.stats.restricted_total_rounds}") + print(f"restricted_total_added_supports = {result.stats.restricted_total_added_supports}") + print(f"restricted_total_fallbacks = {result.stats.restricted_total_fallbacks}") + print(f"full_check_calls = {result.stats.full_check_calls}") + print(f"full_check_max_abs_delta = {result.stats.full_check_max_abs_delta:.12g}") + print(f"elapsed_seconds = {result.stats.elapsed_seconds:.6f}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_subset_detcost.py 
b/src/py/astar/astar_prototype_subset_detcost.py new file mode 100644 index 0000000..f405240 --- /dev/null +++ b/src/py/astar/astar_prototype_subset_detcost.py @@ -0,0 +1,1071 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with detcost and subset-LP heuristics. + +This script keeps the basic search structure of the original prototype while +adding a small CLI and a family of stronger admissible heuristics. + +Heuristic modes: + --opt-subset-detcost-size 0 plain detcost + --opt-subset-detcost-size 1 optimal singleton LP + --opt-subset-detcost-size 2 optimal LP over singletons and 2-detector subsets + --opt-subset-detcost-size 3 optimal LP over singletons and 2/3-detector subsets + +The subset library is the small-subset closure of DEM supports: + * every singleton detector subset, and + * every nonempty subset of D(e) of size at most N, for each DEM error e. + +For a library subset S, the local decoder only sees the restriction of errors to +S. Because N is intended to be small (<= 3 in practice), all minimal local +pattern resolutions can be precomputed once. +""" + +from __future__ import annotations + +import argparse +import heapq +import itertools +import json +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy import sparse +from scipy.optimize import linprog + +INF = math.inf + + +@dataclass(frozen=True) +class MergedError: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
+ + +@dataclass +class DecoderData: + num_detectors: int + num_observables: int + errors: List[MergedError] + detector_to_errors: List[List[int]] + error_costs: np.ndarray + error_detectors: List[Tuple[int, ...]] + error_detector_sets: List[frozenset[int]] + error_observables: List[Tuple[int, ...]] + + +@dataclass(frozen=True) +class SubsetLibraryEntry: + subset_id: int + detectors: Tuple[int, ...] + pattern_to_errors: Dict[int, Tuple[int, ...]] + resolution_combos: Dict[int, Tuple[Tuple[int, ...], ...]] + + +@dataclass +class ActiveSubsetRecord: + subset_id: int + detectors: Tuple[int, ...] + size: int + target_mask: int + available_patterns: Dict[int, Tuple[int, ...]] + feasible_combos: Tuple[Tuple[int, ...], ...] + + +@dataclass +class SearchState: + activated_errors: Tuple[int, ...] + blocked_errors: np.ndarray + active_detectors: np.ndarray + active_detector_counts: np.ndarray + path_cost: float + + +@dataclass +class DecodeStats: + num_pq_pushed: int + num_nodes_popped: int + max_queue_size: int + heuristic_calls: int + lp_calls: int + lp_total_seconds: float + elapsed_seconds: float + heuristic_name: str + + +@dataclass +class DecodeResult: + activated_errors: Tuple[int, ...] 
+ path_cost: float + stats: DecodeStats + + +class UnionFind: + def __init__(self, size: int) -> None: + self.parent = list(range(size)) + self.rank = [0] * size + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + +class LPLogger: + def __init__(self, path: Path, *, every: int = 1, top_k: int = 10) -> None: + self.path = path + self.every = max(1, every) + self.top_k = max(1, top_k) + self.path.parent.mkdir(parents=True, exist_ok=True) + # Truncate eagerly so repeated runs do not append by accident. + self.path.write_text("") + + def maybe_log(self, *, call_index: int, payload: Dict[str, Any]) -> None: + if call_index % self.every != 0: + return + with self.path.open("a", encoding="utf-8") as f: + f.write(json.dumps(payload, sort_keys=True) + "\n") + + +def parse_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "+inf", "infinity", "+infinity"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("beam must be non-negative or 'inf'") + return float(value) + + +def format_indices(indices: Iterable[int], prefix: str) -> str: + items = list(indices) + if not items: + return "(none)" + return " ".join(f"{prefix}{i}" for i in items) + + +def xor_probability(p0: float, p1: float) -> float: + return p0 * (1 - p1) + (1 - p0) * p1 + + +def iter_dem_errors(dem: stim.DetectorErrorModel) -> Iterable[MergedError]: + for instruction in dem.flattened(): + if instruction.type != "error": + continue + probability = float(instruction.args_copy()[0]) + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + "This prototype 
assumes detector-error-model probabilities are in (0, 0.5)." + ) + detectors: set[int] = set() + observables: set[int] = set() + for target in instruction.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected DEM target: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + yield MergedError( + probability=probability, + likelihood_cost=float(-math.log(probability / (1 - probability))), + detectors=tuple(sorted(detectors)), + observables=tuple(sorted(observables)), + ) + + +def merged_errors(dem: stim.DetectorErrorModel) -> List[MergedError]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + for error in iter_dem_errors(dem): + key = (error.detectors, error.observables) + previous = errors_by_symptom.get(key) + if previous is None: + errors_by_symptom[key] = error.probability + else: + errors_by_symptom[key] = xor_probability(previous, error.probability) + + merged: List[MergedError] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + "Merged error has probability >= 0.5, which would give a non-positive cost." 
+ ) + merged.append( + MergedError( + probability=probability, + likelihood_cost=float(-math.log(probability / (1 - probability))), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +def build_decoder_data( + dem: stim.DetectorErrorModel, + *, + merge_errors_in_dem: bool = True, +) -> DecoderData: + errors = merged_errors(dem) if merge_errors_in_dem else list(iter_dem_errors(dem)) + detector_to_errors: List[List[int]] = [[] for _ in range(dem.num_detectors)] + for ei, error in enumerate(errors): + for d in error.detectors: + detector_to_errors[d].append(ei) + return DecoderData( + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + errors=errors, + detector_to_errors=detector_to_errors, + error_costs=np.asarray([e.likelihood_cost for e in errors], dtype=np.float64), + error_detectors=[e.detectors for e in errors], + error_detector_sets=[frozenset(e.detectors) for e in errors], + error_observables=[e.observables for e in errors], + ) + + +def unpack_bit_packed_rows(bits: np.ndarray, count: int) -> np.ndarray: + return np.unpackbits(bits, bitorder="little", axis=1, count=count).astype(bool, copy=False) + + +def initial_detector_counts(data: DecoderData, active_detectors: np.ndarray) -> np.ndarray: + counts = np.zeros(len(data.errors), dtype=np.int32) + for d in np.flatnonzero(active_detectors): + for ei in data.detector_to_errors[int(d)]: + counts[ei] += 1 + return counts + + +def apply_error( + data: DecoderData, + active_detectors: np.ndarray, + active_detector_counts: np.ndarray, + error_index: int, +) -> Tuple[np.ndarray, np.ndarray]: + next_detectors = active_detectors.copy() + next_counts = active_detector_counts.copy() + for d in data.error_detectors[error_index]: + if next_detectors[d]: + next_detectors[d] = False + delta = -1 + else: + next_detectors[d] = True + delta = 1 + for other_error_index in data.detector_to_errors[d]: + next_counts[other_error_index] += delta + return next_detectors, next_counts + + 
+def plain_detcost_for_detector( + data: DecoderData, + detector: int, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, +) -> float: + best = INF + for ei in data.detector_to_errors[detector]: + if blocked_errors[ei]: + continue + count = int(active_detector_counts[ei]) + assert count > 0 + candidate = float(data.error_costs[ei]) / count + if candidate < best: + best = candidate + return best + + +def plain_detcost_heuristic( + data: DecoderData, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, +) -> float: + total = 0.0 + for d in np.flatnonzero(active_detectors): + det_cost = plain_detcost_for_detector( + data=data, + detector=int(d), + blocked_errors=blocked_errors, + active_detector_counts=active_detector_counts, + ) + if det_cost == INF: + return INF + total += det_cost + return total + + +def compute_minimal_resolution_combos( + available_pattern_masks: Iterable[int], + subset_size: int, +) -> Dict[int, Tuple[Tuple[int, ...], ...]]: + """Precompute inclusion-minimal local pattern combinations for each target. + + For a fixed subset S of size k, an error only matters through its nonzero local + pattern D(e)∩S, represented as a bit-mask in {1, ..., 2^k-1}. Because local + budgets are nonnegative, an optimal local resolution never needs to use the same + local pattern twice, and any combo that strictly contains another combo with the + same XOR target is dominated. 
+ """ + + patterns = tuple(sorted(set(available_pattern_masks))) + combos_by_target: Dict[int, List[Tuple[int, ...]]] = { + target: [] for target in range(1, 1 << subset_size) + } + for r in range(1, min(len(patterns), subset_size) + 1): + for combo in itertools.combinations(patterns, r): + target_mask = 0 + for pattern_mask in combo: + target_mask ^= pattern_mask + if target_mask == 0: + continue + combo_set = set(combo) + existing = combos_by_target[target_mask] + keep = True + survivors: List[Tuple[int, ...]] = [] + for old_combo in existing: + old_set = set(old_combo) + if combo_set.issuperset(old_set): + keep = False + survivors.append(old_combo) + elif old_set.issuperset(combo_set): + continue + else: + survivors.append(old_combo) + if keep: + survivors.append(combo) + survivors.sort(key=lambda x: (len(x), x)) + combos_by_target[target_mask] = survivors + return { + target_mask: tuple(combos) + for target_mask, combos in combos_by_target.items() + if combos + } + + +@dataclass +class SubsetLibrary: + max_subset_size: int + entries: List[SubsetLibraryEntry] + subsets_by_detector: List[List[int]] + num_subsets_by_size: Dict[int, int] + + +def build_subset_library(data: DecoderData, max_subset_size: int) -> SubsetLibrary: + library_keys: set[Tuple[int, ...]] = set() + if max_subset_size >= 1: + for detector in range(data.num_detectors): + library_keys.add((detector,)) + + for detectors in data.error_detectors: + limit = min(max_subset_size, len(detectors)) + for subset_size in range(1, limit + 1): + for subset_detectors in itertools.combinations(detectors, subset_size): + library_keys.add(tuple(subset_detectors)) + + subsets_by_detector: List[List[int]] = [[] for _ in range(data.num_detectors)] + entries: List[SubsetLibraryEntry] = [] + num_subsets_by_size: Dict[int, int] = defaultdict(int) + + for subset_id, subset_detectors in enumerate(sorted(library_keys, key=lambda t: (len(t), t))): + pattern_to_errors: Dict[int, List[int]] = defaultdict(list) + for 
error_index, detector_set in enumerate(data.error_detector_sets): + pattern_mask = 0 + for bit_index, detector in enumerate(subset_detectors): + if detector in detector_set: + pattern_mask |= 1 << bit_index + if pattern_mask != 0: + pattern_to_errors[pattern_mask].append(error_index) + frozen_pattern_to_errors = { + pattern_mask: tuple(error_indices) + for pattern_mask, error_indices in pattern_to_errors.items() + } + entry = SubsetLibraryEntry( + subset_id=subset_id, + detectors=subset_detectors, + pattern_to_errors=frozen_pattern_to_errors, + resolution_combos=compute_minimal_resolution_combos( + available_pattern_masks=frozen_pattern_to_errors.keys(), + subset_size=len(subset_detectors), + ), + ) + entries.append(entry) + num_subsets_by_size[len(subset_detectors)] += 1 + for detector in subset_detectors: + subsets_by_detector[detector].append(subset_id) + + return SubsetLibrary( + max_subset_size=max_subset_size, + entries=entries, + subsets_by_detector=subsets_by_detector, + num_subsets_by_size=dict(sorted(num_subsets_by_size.items())), + ) + + +@dataclass +class SubsetLPHeuristicStats: + call_count: int = 0 + lp_call_count: int = 0 + lp_total_seconds: float = 0.0 + + +class SubsetLPHeuristic: + def __init__( + self, + data: DecoderData, + subset_library: SubsetLibrary, + *, + logger: Optional[LPLogger] = None, + ) -> None: + self.data = data + self.subset_library = subset_library + self.logger = logger + self.stats = SubsetLPHeuristicStats() + + def evaluate( + self, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + *, + context: Optional[Dict[str, Any]] = None, + ) -> float: + self.stats.call_count += 1 + self.stats.lp_call_count += 1 + t0 = time.perf_counter() + + active_subset_ids: set[int] = set() + for detector in np.flatnonzero(active_detectors): + active_subset_ids.update(self.subset_library.subsets_by_detector[int(detector)]) + + subset_records: List[ActiveSubsetRecord] = [] + error_to_subset_positions: Dict[int, List[int]] = 
defaultdict(list) + + for subset_id in sorted(active_subset_ids): + entry = self.subset_library.entries[subset_id] + target_mask = 0 + for bit_index, detector in enumerate(entry.detectors): + if active_detectors[detector]: + target_mask |= 1 << bit_index + if target_mask == 0: + continue + + available_patterns: Dict[int, Tuple[int, ...]] = {} + relevant_errors: set[int] = set() + for pattern_mask, error_indices in entry.pattern_to_errors.items(): + kept = tuple(error_index for error_index in error_indices if not blocked_errors[error_index]) + if kept: + available_patterns[pattern_mask] = kept + relevant_errors.update(kept) + + feasible_combos = tuple( + combo + for combo in entry.resolution_combos.get(target_mask, ()) + if all(pattern_mask in available_patterns for pattern_mask in combo) + ) + if not feasible_combos: + self.stats.lp_total_seconds += time.perf_counter() - t0 + return INF + + record = ActiveSubsetRecord( + subset_id=subset_id, + detectors=entry.detectors, + size=len(entry.detectors), + target_mask=target_mask, + available_patterns=available_patterns, + feasible_combos=feasible_combos, + ) + subset_position = len(subset_records) + subset_records.append(record) + for error_index in sorted(relevant_errors): + error_to_subset_positions[error_index].append(subset_position) + + if not subset_records: + elapsed = time.perf_counter() - t0 + self.stats.lp_total_seconds += elapsed + if self.logger is not None: + payload: Dict[str, Any] = { + "call_index": self.stats.call_count, + "objective": 0.0, + "solve_seconds": elapsed, + "num_active_subsets": 0, + "num_components": 0, + } + if context is not None: + payload.update(context) + self.logger.maybe_log(call_index=self.stats.call_count, payload=payload) + return 0.0 + + component_uf = UnionFind(len(subset_records)) + for subset_positions in error_to_subset_positions.values(): + for position in subset_positions[1:]: + component_uf.union(subset_positions[0], position) + component_to_subset_positions: Dict[int, 
List[int]] = defaultdict(list) + for subset_position in range(len(subset_records)): + component_to_subset_positions[component_uf.find(subset_position)].append(subset_position) + + total_objective = 0.0 + total_num_variables = 0 + total_num_constraints = 0 + contribution_by_size: Dict[int, float] = defaultdict(float) + budget_by_size: Dict[int, float] = defaultdict(float) + active_subset_count_by_size: Dict[int, int] = defaultdict(int) + top_subset_records: List[Dict[str, Any]] = [] + + for component_positions in component_to_subset_positions.values(): + y_var: Dict[int, int] = {} + u_var: Dict[Tuple[int, int], int] = {} + error_to_u_vars: Dict[int, List[int]] = defaultdict(list) + + next_var_index = 0 + for subset_position in component_positions: + y_var[subset_position] = next_var_index + next_var_index += 1 + for subset_position in component_positions: + record = subset_records[subset_position] + active_subset_count_by_size[record.size] += 1 + for pattern_mask, error_indices in sorted(record.available_patterns.items()): + variable_index = next_var_index + u_var[(subset_position, pattern_mask)] = variable_index + next_var_index += 1 + for error_index in error_indices: + error_to_u_vars[error_index].append(variable_index) + + row_indices: List[int] = [] + col_indices: List[int] = [] + values: List[float] = [] + rhs: List[float] = [] + + for error_index, variable_indices in sorted(error_to_u_vars.items()): + row = len(rhs) + rhs.append(float(self.data.error_costs[error_index])) + for variable_index in variable_indices: + row_indices.append(row) + col_indices.append(variable_index) + values.append(1.0) + + for subset_position in component_positions: + record = subset_records[subset_position] + y_index = y_var[subset_position] + for combo in record.feasible_combos: + row = len(rhs) + rhs.append(0.0) + row_indices.append(row) + col_indices.append(y_index) + values.append(1.0) + for pattern_mask in combo: + row_indices.append(row) + 
col_indices.append(u_var[(subset_position, pattern_mask)]) + values.append(-1.0) + + total_num_variables += next_var_index + total_num_constraints += len(rhs) + + a_ub = sparse.csr_matrix( + (values, (row_indices, col_indices)), + shape=(len(rhs), next_var_index), + dtype=np.float64, + ) + objective = np.zeros(next_var_index, dtype=np.float64) + for subset_position in component_positions: + objective[y_var[subset_position]] = -1.0 + + result = linprog( + c=objective, + A_ub=a_ub, + b_ub=np.asarray(rhs, dtype=np.float64), + bounds=[(0.0, None)] * next_var_index, + method="highs", + ) + if not result.success: + raise RuntimeError( + f"subset detcost LP solve failed: status={result.status} message={result.message}" + ) + total_objective += float(-result.fun) + solution = np.asarray(result.x, dtype=np.float64) + + for subset_position in component_positions: + record = subset_records[subset_position] + y_value = float(solution[y_var[subset_position]]) + total_budget = float( + sum(solution[u_var[(subset_position, pattern_mask)]] for pattern_mask in record.available_patterns) + ) + contribution_by_size[record.size] += y_value + budget_by_size[record.size] += total_budget + pattern_values = [ + { + "pattern_detectors": [ + detector + for bit_index, detector in enumerate(record.detectors) + if pattern_mask & (1 << bit_index) + ], + "u": float(solution[u_var[(subset_position, pattern_mask)]]), + "num_allowed_errors": len(record.available_patterns[pattern_mask]), + } + for pattern_mask in sorted(record.available_patterns) + if solution[u_var[(subset_position, pattern_mask)]] > 1e-12 + ] + top_subset_records.append( + { + "subset_detectors": list(record.detectors), + "subset_size": record.size, + "target_active_detectors": [ + detector + for bit_index, detector in enumerate(record.detectors) + if record.target_mask & (1 << bit_index) + ], + "y": y_value, + "total_budget": total_budget, + "num_available_patterns": len(record.available_patterns), + 
"num_feasible_resolution_combos": len(record.feasible_combos), + "patterns": pattern_values, + } + ) + + elapsed = time.perf_counter() - t0 + self.stats.lp_total_seconds += elapsed + + if self.logger is not None: + top_subset_records.sort(key=lambda item: (-item["y"], -item["total_budget"], item["subset_detectors"])) + payload = { + "call_index": self.stats.call_count, + "objective": total_objective, + "solve_seconds": elapsed, + "num_active_subsets": len(subset_records), + "num_active_subsets_by_size": { + str(size): active_subset_count_by_size[size] for size in sorted(active_subset_count_by_size) + }, + "num_components": len(component_to_subset_positions), + "num_variables": total_num_variables, + "num_constraints": total_num_constraints, + "contribution_by_subset_size": { + str(size): contribution_by_size[size] for size in sorted(contribution_by_size) + }, + "allocated_budget_by_subset_size": { + str(size): budget_by_size[size] for size in sorted(budget_by_size) + }, + "top_subsets": top_subset_records[: self.logger.top_k], + } + if context is not None: + payload.update(context) + self.logger.maybe_log(call_index=self.stats.call_count, payload=payload) + + return total_objective + +def compute_heuristic( + data: DecoderData, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, + *, + opt_subset_solver: Optional[SubsetLPHeuristic], + context: Optional[Dict[str, Any]] = None, +) -> float: + if opt_subset_solver is None: + return plain_detcost_heuristic( + data=data, + active_detectors=active_detectors, + blocked_errors=blocked_errors, + active_detector_counts=active_detector_counts, + ) + del active_detector_counts + return opt_subset_solver.evaluate( + active_detectors=active_detectors, + blocked_errors=blocked_errors, + context=context, + ) + + +def detectors_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + detectors = np.zeros(data.num_detectors, dtype=bool) + for error_index in 
activated_errors: + for detector in data.error_detectors[error_index]: + detectors[detector] ^= True + return detectors + + +def observables_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + observables = np.zeros(data.num_observables, dtype=bool) + for error_index in activated_errors: + for observable in data.error_observables[error_index]: + observables[observable] ^= True + return observables + + +def decode( + data: DecoderData, + detections: np.ndarray, + *, + det_beam: float = INF, + opt_subset_solver: Optional[SubsetLPHeuristic] = None, + verbose_search: bool = False, +) -> DecodeResult: + start_time = time.perf_counter() + initial_active_detectors = np.asarray(detections, dtype=bool).copy() + initial_counts = initial_detector_counts(data, initial_active_detectors) + initial_blocked = np.zeros(len(data.errors), dtype=bool) + initial_path_cost = 0.0 + initial_heuristic = compute_heuristic( + data=data, + active_detectors=initial_active_detectors, + blocked_errors=initial_blocked, + active_detector_counts=initial_counts, + opt_subset_solver=opt_subset_solver, + context={ + "phase": "initial", + "depth": 0, + "nodes_popped": 0, + "path_cost": 0.0, + "active_detector_count": int(initial_active_detectors.sum()), + }, + ) + if initial_heuristic == INF: + raise RuntimeError("Initial residual syndrome is infeasible under the current pruning rule.") + + initial_state = SearchState( + activated_errors=(), + blocked_errors=initial_blocked, + active_detectors=initial_active_detectors, + active_detector_counts=initial_counts, + path_cost=initial_path_cost, + ) + + priority_queue: List[Tuple[float, int, int, SearchState]] = [] + push_counter = 0 + initial_num_dets = int(initial_active_detectors.sum()) + heapq.heappush( + priority_queue, + (initial_path_cost + initial_heuristic, initial_num_dets, push_counter, initial_state), + ) + push_counter += 1 + + num_pq_pushed = 1 + num_nodes_popped = 0 + max_queue_size = 1 + min_num_dets = 
def decode(
    data: "DecoderData",
    detections: "np.ndarray",
    *,
    det_beam: float = INF,
    opt_subset_solver: "Optional[SubsetLPHeuristic]" = None,
    verbose_search: bool = False,
) -> "DecodeResult":
    """A* search over partial error sets until the residual syndrome is empty.

    Nodes are keyed by (f = g + h, residual detector count, insertion order).
    A beam on the residual detector count (``det_beam``) prunes nodes that
    stray too far above the best residual seen so far. Precedence blocking:
    when expanding on the minimum active detector, each tried error is blocked
    for all later siblings, which prunes permutations of the same error set.

    Raises ``RuntimeError`` if the root is infeasible or the queue empties.
    """
    start_time = time.perf_counter()

    # Root node: the raw syndrome, nothing chosen, nothing blocked.
    root_detectors = np.asarray(detections, dtype=bool).copy()
    root_counts = initial_detector_counts(data, root_detectors)
    root_blocked = np.zeros(len(data.errors), dtype=bool)
    root_g = 0.0
    root_h = compute_heuristic(
        data=data,
        active_detectors=root_detectors,
        blocked_errors=root_blocked,
        active_detector_counts=root_counts,
        opt_subset_solver=opt_subset_solver,
        context={
            "phase": "initial",
            "depth": 0,
            "nodes_popped": 0,
            "path_cost": 0.0,
            "active_detector_count": int(root_detectors.sum()),
        },
    )
    if root_h == INF:
        raise RuntimeError("Initial residual syndrome is infeasible under the current pruning rule.")

    root_state = SearchState(
        activated_errors=(),
        blocked_errors=root_blocked,
        active_detectors=root_detectors,
        active_detector_counts=root_counts,
        path_cost=root_g,
    )

    open_heap: List[Tuple[float, int, int, SearchState]] = []
    tie_breaker = 0
    root_num_dets = int(root_detectors.sum())
    heapq.heappush(open_heap, (root_g + root_h, root_num_dets, tie_breaker, root_state))
    tie_breaker += 1

    num_pq_pushed = 1
    num_nodes_popped = 0
    max_queue_size = 1
    min_num_dets = root_num_dets
    max_num_dets = INF if det_beam == INF else min_num_dets + det_beam

    heuristic_name = (
        f"opt_subset_detcost_size_{opt_subset_solver.subset_library.max_subset_size}"
        if opt_subset_solver is not None
        else "plain_detcost"
    )

    while open_heap:
        max_queue_size = max(max_queue_size, len(open_heap))
        f_cost, num_dets, _, state = heapq.heappop(open_heap)
        num_nodes_popped += 1

        # Beam pruning on the residual detector count.
        if num_dets > max_num_dets:
            continue
        if num_dets < min_num_dets:
            # A new best residual tightens the beam window.
            min_num_dets = num_dets
            max_num_dets = INF if det_beam == INF else min_num_dets + det_beam

        if verbose_search:
            print(
                f"nodes_popped={num_nodes_popped} len(pq)={len(open_heap)} "
                f"active_dets={num_dets} beam_max={max_num_dets} depth={len(state.activated_errors)} "
                f"f={f_cost:.12g} g={state.path_cost:.12g}"
            )

        if num_dets == 0:
            # Goal: every detector resolved.
            elapsed_seconds = time.perf_counter() - start_time
            heuristic_calls = 0 if opt_subset_solver is None else opt_subset_solver.stats.call_count
            lp_calls = 0 if opt_subset_solver is None else opt_subset_solver.stats.lp_call_count
            lp_total_seconds = 0.0 if opt_subset_solver is None else opt_subset_solver.stats.lp_total_seconds
            return DecodeResult(
                activated_errors=state.activated_errors,
                path_cost=state.path_cost,
                stats=DecodeStats(
                    num_pq_pushed=num_pq_pushed,
                    num_nodes_popped=num_nodes_popped,
                    max_queue_size=max_queue_size,
                    heuristic_calls=heuristic_calls,
                    lp_calls=lp_calls,
                    lp_total_seconds=lp_total_seconds,
                    elapsed_seconds=elapsed_seconds,
                    heuristic_name=heuristic_name,
                ),
            )

        # Expand on the lowest-indexed active detector; every completion must
        # eventually resolve it, so branching here is complete.
        min_detector = int(np.flatnonzero(state.active_detectors)[0])
        blocked_prefix = state.blocked_errors.copy()
        for error_index in data.detector_to_errors[min_detector]:
            # Block this error for all later siblings (precedence pruning),
            # then skip it here if an ancestor already blocked it.
            blocked_prefix[error_index] = True
            if state.blocked_errors[error_index]:
                continue

            child_detectors, child_counts = apply_error(
                data=data,
                active_detectors=state.active_detectors,
                active_detector_counts=state.active_detector_counts,
                error_index=error_index,
            )
            child_num_dets = int(child_detectors.sum())
            if child_num_dets > max_num_dets:
                continue

            child_blocked = blocked_prefix.copy()
            child_g = state.path_cost + float(data.error_costs[error_index])
            child_h = compute_heuristic(
                data=data,
                active_detectors=child_detectors,
                blocked_errors=child_blocked,
                active_detector_counts=child_counts,
                opt_subset_solver=opt_subset_solver,
                context={
                    "phase": "child",
                    "depth": len(state.activated_errors) + 1,
                    "nodes_popped": num_nodes_popped,
                    "path_cost": child_g,
                    "active_detector_count": child_num_dets,
                    "chosen_error": error_index,
                    "min_detector": min_detector,
                },
            )
            if child_h == INF:
                continue

            child_state = SearchState(
                activated_errors=state.activated_errors + (error_index,),
                blocked_errors=child_blocked,
                active_detectors=child_detectors,
                active_detector_counts=child_counts,
                path_cost=child_g,
            )
            heapq.heappush(
                open_heap,
                (child_g + child_h, child_num_dets, tie_breaker, child_state),
            )
            tie_breaker += 1
            num_pq_pushed += 1

    raise RuntimeError("Decoding failed to find any completion.")
+ ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a Stim circuit file.") + parser.add_argument( + "--shot", + type=int, + default=0, + help="Zero-based sampled shot index to decode.", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample before selecting --shot.", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Seed passed to stim.compile_detector_sampler(...).sample(...).", + ) + parser.add_argument( + "--det-beam", + type=parse_beam, + default=INF, + help="Beam cutoff on the residual detector count. Use an integer or 'inf'.", + ) + parser.add_argument( + "--opt-subset-detcost-size", + type=int, + default=0, + help=( + "Use the subset-based LP heuristic with library subsets of size at most N. " + "Use 0 for plain detcost, 1 for the optimal singleton LP, etc." + ), + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the sampled shot's active detector IDs before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the activated error indices in the final decoding.", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + parser.add_argument( + "--lp-log-path", + type=Path, + default=None, + help="Optional JSONL file for logging details of each subset-LP solve.", + ) + parser.add_argument( + "--lp-log-top-k", + type=int, + default=10, + help="When logging LP solves, include at most this many top subsets per solve.", + ) + parser.add_argument( + "--lp-log-every", + type=int, + default=1, + help="When logging LP solves, only write every k-th 
def main(argv: Optional[Sequence[str]] = None) -> int:
    """CLI entry point: sample shots, decode one, and print diagnostics.

    Returns 0 on success; argparse errors exit via ``parser.error``.
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)

    # Argument sanity checks (argparse types only enforce parseability).
    if args.sample_num_shots <= 0:
        parser.error("--sample-num-shots must be positive.")
    if args.shot < 0:
        parser.error("--shot must be non-negative.")
    if args.opt_subset_detcost_size < 0:
        parser.error("--opt-subset-detcost-size must be non-negative.")
    if args.lp_log_every <= 0:
        parser.error("--lp-log-every must be positive.")
    if args.lp_log_top_k <= 0:
        parser.error("--lp-log-top-k must be positive.")

    # Build the decoding problem from the circuit's detector error model.
    circuit = stim.Circuit.from_file(str(args.circuit))
    dem = circuit.detector_error_model(decompose_errors=False)
    data = build_decoder_data(dem, merge_errors_in_dem=args.merge_errors)

    # Optional subset-LP heuristic (size 0 means plain detcost).
    subset_library = None
    subset_solver = None
    if args.opt_subset_detcost_size > 0:
        subset_library = build_subset_library(data, args.opt_subset_detcost_size)
        lp_logger = None
        if args.lp_log_path is not None:
            lp_logger = LPLogger(
                args.lp_log_path,
                every=args.lp_log_every,
                top_k=args.lp_log_top_k,
            )
        subset_solver = SubsetLPHeuristic(data, subset_library, logger=lp_logger)

    # Sample detector/observable data and pick the requested shot.
    dets_packed, obs_packed = circuit.compile_detector_sampler(seed=args.seed).sample(
        shots=args.sample_num_shots,
        separate_observables=True,
        bit_packed=True,
    )
    detections = unpack_bit_packed_rows(dets_packed, count=dem.num_detectors)
    observables = unpack_bit_packed_rows(obs_packed, count=dem.num_observables)

    if args.shot >= detections.shape[0]:
        parser.error(f"--shot={args.shot} is out of range for {detections.shape[0]} sampled shots.")

    shot_detections = detections[args.shot]
    shot_observables = observables[args.shot] if observables.size else np.zeros(0, dtype=bool)

    # Pre-decode report.
    print(f"circuit = {args.circuit}")
    print(
        "heuristic = "
        + (
            "plain_detcost"
            if subset_solver is None
            else f"opt_subset_detcost_size_{subset_library.max_subset_size}"
        )
    )
    print(f"shot = {args.shot}")
    print(f"sample_num_shots = {args.sample_num_shots}")
    print(f"num_detectors = {data.num_detectors}")
    print(f"num_observables = {data.num_observables}")
    print(f"num_errors = {len(data.errors)}")
    print(f"beam = {args.det_beam}")
    if subset_library is not None:
        print(f"subset_library_size = {len(subset_library.entries)}")
        print(
            "subset_library_by_size = "
            + ", ".join(
                f"{size}:{count}" for size, count in subset_library.num_subsets_by_size.items()
            )
        )
    if args.show_shot_detectors:
        print(f"shot_detectors = {format_indices(np.flatnonzero(shot_detections), 'D')}")

    result = decode(
        data=data,
        detections=shot_detections,
        det_beam=args.det_beam,
        opt_subset_solver=subset_solver,
        verbose_search=args.verbose_search,
    )

    # Sanity check: the decoded error set must reproduce the syndrome exactly.
    predicted_observables = observables_from_solution(data, result.activated_errors)
    reproduced_detectors = detectors_from_solution(data, result.activated_errors)
    if not np.array_equal(reproduced_detectors, shot_detections):
        raise AssertionError("Decoded error set does not reproduce the shot's syndrome.")

    # Post-decode report.
    print(f"solution_size = {len(result.activated_errors)}")
    print(f"solution_cost = {result.path_cost:.12g}")
    if args.show_error_indices:
        print(f"activated_errors = {format_indices(result.activated_errors, 'E')}")
    print(f"predicted_observables = {format_indices(np.flatnonzero(predicted_observables), 'L')}")
    print(f"sample_observables = {format_indices(np.flatnonzero(shot_observables), 'L')}")
    print(f"observables_match = {bool(np.array_equal(predicted_observables, shot_observables))}")
    print(f"num_pq_pushed = {result.stats.num_pq_pushed}")
    print(f"num_nodes_popped = {result.stats.num_nodes_popped}")
    print(f"max_queue_size = {result.stats.max_queue_size}")
    print(f"heuristic_calls = {result.stats.heuristic_calls}")
    print(f"lp_calls = {result.stats.lp_calls}")
    print(f"lp_total_seconds = {result.stats.lp_total_seconds:.6f}")
    print(f"elapsed_seconds = {result.stats.elapsed_seconds:.6f}")

    return 0
return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_subset_detcost_lazy.py b/src/py/astar/astar_prototype_subset_detcost_lazy.py new file mode 100644 index 0000000..8bff114 --- /dev/null +++ b/src/py/astar/astar_prototype_subset_detcost_lazy.py @@ -0,0 +1,1314 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with lazy subset-LP refinement. + +Heuristic modes: + --opt-subset-detcost-size 0 plain detcost + --opt-subset-detcost-size 1 lazy optimal singleton LP + --opt-subset-detcost-size 2 lazy optimal LP over size-1/2 subsets + --opt-subset-detcost-size 3 lazy optimal LP over size-1/2/3 subsets + +For subset size N > 0, the search uses lazy refinement: + * nodes are first inserted using a cheap lower bound; + * when popped, the exact subset LP is solved; + * if the exact LP raises the node key, the node is reinserted; + * expanded nodes project their exact subset-pattern prices onto children. + +The projection step is the main subtlety relative to the singleton case. +The exact parent LP stores prices u_{S,t} for subset/pattern pairs. For a child, +we keep inherited u_{S,t} values on patterns still available, zero out patterns +that have become unavailable, assign zero to newly active subsets, and recompute +child y_S values as the minimum cost of a feasible local signature decomposition +under those inherited prices. +""" + +from __future__ import annotations + +import argparse +import heapq +import itertools +import json +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy import sparse +from scipy.optimize import linprog + +INF = math.inf +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class MergedError: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
+ + +@dataclass +class DecoderData: + num_detectors: int + num_observables: int + errors: List[MergedError] + detector_to_errors: List[List[int]] + error_costs: np.ndarray + error_detectors: List[Tuple[int, ...]] + error_detector_sets: List[frozenset[int]] + error_observables: List[Tuple[int, ...]] + + +@dataclass(frozen=True) +class SubsetLibraryEntry: + subset_id: int + detectors: Tuple[int, ...] + pattern_to_errors: Dict[int, Tuple[int, ...]] + resolution_combos: Dict[int, Tuple[Tuple[int, ...], ...]] + + +@dataclass +class ActiveSubsetRecord: + subset_id: int + detectors: Tuple[int, ...] + size: int + target_mask: int + available_patterns: Dict[int, Tuple[int, ...]] + feasible_combos: Tuple[Tuple[int, ...], ...] + + +@dataclass +class SearchState: + activated_errors: Tuple[int, ...] + blocked_errors: np.ndarray + active_detectors: np.ndarray + active_detector_counts: np.ndarray + path_cost: float + heuristic_cost: float + heuristic_source: str + exact_refined: bool + lp_solution: Optional["SubsetLPSolution"] = None + + +@dataclass +class DecodeStats: + num_pq_pushed: int + num_nodes_popped: int + max_queue_size: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + exact_refinement_calls: int + lp_calls: int + lp_reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_lp_refinement_gain: float + max_lp_refinement_gain: float + lp_total_seconds: float + elapsed_seconds: float + heuristic_name: str + + +@dataclass +class DecodeResult: + activated_errors: Tuple[int, ...] 
+ path_cost: float + stats: DecodeStats + + +class UnionFind: + def __init__(self, size: int) -> None: + self.parent = list(range(size)) + self.rank = [0] * size + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + +class LPLogger: + def __init__(self, path: Path, *, every: int = 1, top_k: int = 10) -> None: + self.path = path + self.every = max(1, every) + self.top_k = max(1, top_k) + self.path.parent.mkdir(parents=True, exist_ok=True) + self.path.write_text("") + + def maybe_log(self, *, call_index: int, payload: Dict[str, Any]) -> None: + if call_index % self.every != 0: + return + with self.path.open("a", encoding="utf-8") as f: + f.write(json.dumps(payload, sort_keys=True) + "\n") + + +@dataclass +class SubsetLibrary: + max_subset_size: int + entries: List[SubsetLibraryEntry] + subsets_by_detector: List[List[int]] + num_subsets_by_size: Dict[int, int] + + +@dataclass +class SubsetLPSolution: + value: float + subset_u_values: Dict[int, Dict[int, float]] + num_active_subsets: int + num_components: int + num_variables: int + num_constraints: int + + +@dataclass +class SubsetLPSolverStats: + lp_calls: int = 0 + lp_total_seconds: float = 0.0 + + +class SubsetLPHeuristic: + def __init__( + self, + data: DecoderData, + subset_library: SubsetLibrary, + *, + logger: Optional[LPLogger] = None, + ) -> None: + self.data = data + self.subset_library = subset_library + self.logger = logger + self.stats = SubsetLPSolverStats() + + def reset_stats(self) -> None: + self.stats = SubsetLPSolverStats() + + def _collect_active_subset_records( + self, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + 
) -> Tuple[Optional[List[ActiveSubsetRecord]], Optional[Dict[int, List[int]]]]: + active_subset_ids: set[int] = set() + for detector in np.flatnonzero(active_detectors): + active_subset_ids.update(self.subset_library.subsets_by_detector[int(detector)]) + + subset_records: List[ActiveSubsetRecord] = [] + error_to_subset_positions: Dict[int, List[int]] = defaultdict(list) + + for subset_id in sorted(active_subset_ids): + entry = self.subset_library.entries[subset_id] + target_mask = 0 + for bit_index, detector in enumerate(entry.detectors): + if active_detectors[detector]: + target_mask |= 1 << bit_index + if target_mask == 0: + continue + + available_patterns: Dict[int, Tuple[int, ...]] = {} + relevant_errors: set[int] = set() + for pattern_mask, error_indices in entry.pattern_to_errors.items(): + kept = tuple(error_index for error_index in error_indices if not blocked_errors[error_index]) + if kept: + available_patterns[pattern_mask] = kept + relevant_errors.update(kept) + + feasible_combos = tuple( + combo + for combo in entry.resolution_combos.get(target_mask, ()) + if all(pattern_mask in available_patterns for pattern_mask in combo) + ) + if not feasible_combos: + return None, None + + subset_position = len(subset_records) + subset_records.append( + ActiveSubsetRecord( + subset_id=subset_id, + detectors=entry.detectors, + size=len(entry.detectors), + target_mask=target_mask, + available_patterns=available_patterns, + feasible_combos=feasible_combos, + ) + ) + for error_index in sorted(relevant_errors): + error_to_subset_positions[error_index].append(subset_position) + + return subset_records, error_to_subset_positions + + def solve_exact( + self, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + ) -> Tuple[SubsetLPSolution, Dict[str, Any]]: + t0 = time.perf_counter() + subset_records, error_to_subset_positions = self._collect_active_subset_records( + active_detectors=active_detectors, + blocked_errors=blocked_errors, + ) + if subset_records is 
None: + elapsed = time.perf_counter() - t0 + self.stats.lp_total_seconds += elapsed + payload = { + "objective": INF, + "solve_seconds": elapsed, + "num_active_subsets": 0, + "num_components": 0, + "num_variables": 0, + "num_constraints": 0, + "num_active_subsets_by_size": {}, + "contribution_by_subset_size": {}, + "allocated_budget_by_subset_size": {}, + "top_subsets": [], + "structurally_infeasible": True, + } + return ( + SubsetLPSolution( + value=INF, + subset_u_values={}, + num_active_subsets=0, + num_components=0, + num_variables=0, + num_constraints=0, + ), + payload, + ) + + if not subset_records: + elapsed = time.perf_counter() - t0 + self.stats.lp_total_seconds += elapsed + payload = { + "objective": 0.0, + "solve_seconds": elapsed, + "num_active_subsets": 0, + "num_components": 0, + "num_variables": 0, + "num_constraints": 0, + "num_active_subsets_by_size": {}, + "contribution_by_subset_size": {}, + "allocated_budget_by_subset_size": {}, + "top_subsets": [], + "structurally_infeasible": False, + } + return ( + SubsetLPSolution( + value=0.0, + subset_u_values={}, + num_active_subsets=0, + num_components=0, + num_variables=0, + num_constraints=0, + ), + payload, + ) + + component_uf = UnionFind(len(subset_records)) + for subset_positions in error_to_subset_positions.values(): + for position in subset_positions[1:]: + component_uf.union(subset_positions[0], position) + component_to_subset_positions: Dict[int, List[int]] = defaultdict(list) + for subset_position in range(len(subset_records)): + component_to_subset_positions[component_uf.find(subset_position)].append(subset_position) + + total_objective = 0.0 + total_num_variables = 0 + total_num_constraints = 0 + subset_u_values: Dict[int, Dict[int, float]] = {} + contribution_by_size: Dict[int, float] = defaultdict(float) + budget_by_size: Dict[int, float] = defaultdict(float) + active_subset_count_by_size: Dict[int, int] = defaultdict(int) + top_subset_records: List[Dict[str, Any]] = [] + need_log_details 
= self.logger is not None + + for component_positions in component_to_subset_positions.values(): + y_var: Dict[int, int] = {} + u_var: Dict[Tuple[int, int], int] = {} + error_to_u_vars: Dict[int, List[int]] = defaultdict(list) + + next_var_index = 0 + for subset_position in component_positions: + y_var[subset_position] = next_var_index + next_var_index += 1 + for subset_position in component_positions: + record = subset_records[subset_position] + active_subset_count_by_size[record.size] += 1 + for pattern_mask, error_indices in sorted(record.available_patterns.items()): + variable_index = next_var_index + u_var[(subset_position, pattern_mask)] = variable_index + next_var_index += 1 + for error_index in error_indices: + error_to_u_vars[error_index].append(variable_index) + + row_indices: List[int] = [] + col_indices: List[int] = [] + values: List[float] = [] + rhs: List[float] = [] + + for error_index, variable_indices in sorted(error_to_u_vars.items()): + row = len(rhs) + rhs.append(float(self.data.error_costs[error_index])) + for variable_index in variable_indices: + row_indices.append(row) + col_indices.append(variable_index) + values.append(1.0) + + for subset_position in component_positions: + record = subset_records[subset_position] + y_index = y_var[subset_position] + for combo in record.feasible_combos: + row = len(rhs) + rhs.append(0.0) + row_indices.append(row) + col_indices.append(y_index) + values.append(1.0) + for pattern_mask in combo: + row_indices.append(row) + col_indices.append(u_var[(subset_position, pattern_mask)]) + values.append(-1.0) + + total_num_variables += next_var_index + total_num_constraints += len(rhs) + + a_ub = sparse.csr_matrix( + (values, (row_indices, col_indices)), + shape=(len(rhs), next_var_index), + dtype=np.float64, + ) + objective = np.zeros(next_var_index, dtype=np.float64) + for subset_position in component_positions: + objective[y_var[subset_position]] = -1.0 + + self.stats.lp_calls += 1 + result = linprog( + c=objective, 
+ A_ub=a_ub, + b_ub=np.asarray(rhs, dtype=np.float64), + bounds=[(0.0, None)] * next_var_index, + method="highs", + ) + if not result.success: + raise RuntimeError( + f"subset detcost LP solve failed: status={result.status} message={result.message}" + ) + total_objective += float(-result.fun) + solution = np.asarray(result.x, dtype=np.float64) + + for subset_position in component_positions: + record = subset_records[subset_position] + subset_pattern_values: Dict[int, float] = {} + total_budget = 0.0 + for pattern_mask in sorted(record.available_patterns): + u_value = float(solution[u_var[(subset_position, pattern_mask)]]) + total_budget += u_value + if u_value > 1e-12: + subset_pattern_values[pattern_mask] = u_value + if subset_pattern_values: + subset_u_values[record.subset_id] = subset_pattern_values + + if need_log_details: + y_value = float(solution[y_var[subset_position]]) + contribution_by_size[record.size] += y_value + budget_by_size[record.size] += total_budget + pattern_values = [ + { + "pattern_detectors": [ + detector + for bit_index, detector in enumerate(record.detectors) + if pattern_mask & (1 << bit_index) + ], + "u": float(solution[u_var[(subset_position, pattern_mask)]]), + "num_allowed_errors": len(record.available_patterns[pattern_mask]), + } + for pattern_mask in sorted(record.available_patterns) + if solution[u_var[(subset_position, pattern_mask)]] > 1e-12 + ] + top_subset_records.append( + { + "subset_detectors": list(record.detectors), + "subset_size": record.size, + "target_active_detectors": [ + detector + for bit_index, detector in enumerate(record.detectors) + if record.target_mask & (1 << bit_index) + ], + "y": y_value, + "total_budget": total_budget, + "num_available_patterns": len(record.available_patterns), + "num_feasible_resolution_combos": len(record.feasible_combos), + "patterns": pattern_values, + } + ) + + elapsed = time.perf_counter() - t0 + self.stats.lp_total_seconds += elapsed + + if need_log_details: + 
top_subset_records.sort(key=lambda item: (-item["y"], -item["total_budget"], item["subset_detectors"])) + payload = { + "objective": total_objective, + "solve_seconds": elapsed, + "num_active_subsets": len(subset_records), + "num_components": len(component_to_subset_positions), + "num_variables": total_num_variables, + "num_constraints": total_num_constraints, + "num_active_subsets_by_size": { + str(size): active_subset_count_by_size[size] for size in sorted(active_subset_count_by_size) + }, + "contribution_by_subset_size": ( + {str(size): contribution_by_size[size] for size in sorted(contribution_by_size)} + if need_log_details + else {} + ), + "allocated_budget_by_subset_size": ( + {str(size): budget_by_size[size] for size in sorted(budget_by_size)} + if need_log_details + else {} + ), + "top_subsets": top_subset_records[: self.logger.top_k] if self.logger is not None else [], + "structurally_infeasible": False, + } + return ( + SubsetLPSolution( + value=total_objective, + subset_u_values=subset_u_values, + num_active_subsets=len(subset_records), + num_components=len(component_to_subset_positions), + num_variables=total_num_variables, + num_constraints=total_num_constraints, + ), + payload, + ) + + def project_from_parent( + self, + parent_solution: SubsetLPSolution, + child_active_detectors: np.ndarray, + child_blocked_errors: np.ndarray, + ) -> float: + total = 0.0 + active_subset_ids: set[int] = set() + for detector in np.flatnonzero(child_active_detectors): + active_subset_ids.update(self.subset_library.subsets_by_detector[int(detector)]) + + for subset_id in sorted(active_subset_ids): + entry = self.subset_library.entries[subset_id] + target_mask = 0 + for bit_index, detector in enumerate(entry.detectors): + if child_active_detectors[detector]: + target_mask |= 1 << bit_index + if target_mask == 0: + continue + + combos = entry.resolution_combos.get(target_mask, ()) + if not combos: + return INF + + parent_u = parent_solution.subset_u_values.get(subset_id, 
def parse_beam(text: str) -> float:
    """argparse type for --det-beam: a non-negative integer or 'inf'."""
    token = text.strip().lower()
    if token in {"inf", "+inf", "infinity", "+infinity"}:
        return INF
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError("beam must be non-negative or 'inf'")
    return float(value)


def format_indices(indices: Iterable[int], prefix: str) -> str:
    """Render indices as e.g. 'D0 D3 D7'; '(none)' when empty."""
    labels = [f"{prefix}{index}" for index in indices]
    if not labels:
        return "(none)"
    return " ".join(labels)


def xor_probability(p0: float, p1: float) -> float:
    """Probability that exactly one of two independent events occurs."""
    return p0 * (1 - p1) + (1 - p0) * p1


def iter_dem_errors(dem: "stim.DetectorErrorModel") -> "Iterable[MergedError]":
    """Yield one MergedError per DEM ``error`` instruction with p in (0, 0.5).

    Repeated detector/observable targets within a single instruction cancel
    in pairs (XOR semantics); zero-probability errors are dropped.

    Raises ``ValueError`` for probabilities >= 0.5 or unexpected targets.
    """
    for instruction in dem.flattened():
        if instruction.type != "error":
            continue
        probability = float(instruction.args_copy()[0])
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                "This prototype assumes detector-error-model probabilities are in (0, 0.5)."
            )
        detectors: set = set()
        observables: set = set()
        for target in instruction.targets_copy():
            if target.is_separator():
                continue
            if target.is_logical_observable_id():
                # Symmetric difference toggles membership (pairs cancel).
                observables ^= {target.val}
            else:
                if not target.is_relative_detector_id():
                    raise ValueError(f"Unexpected DEM target: {target!r}")
                detectors ^= {target.val}
        yield MergedError(
            probability=probability,
            likelihood_cost=float(-math.log(probability / (1 - probability))),
            detectors=tuple(sorted(detectors)),
            observables=tuple(sorted(observables)),
        )


def merged_errors(dem: "stim.DetectorErrorModel") -> "List[MergedError]":
    """Merge DEM errors with identical (detectors, observables) symptoms.

    Probabilities of indistinguishable errors combine by XOR; merged entries
    with non-positive probability are dropped.

    Raises ``ValueError`` when a merged probability reaches 0.5 (which would
    give a non-positive likelihood cost).
    """
    probability_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {}
    for error in iter_dem_errors(dem):
        key = (error.detectors, error.observables)
        if key in probability_by_symptom:
            probability_by_symptom[key] = xor_probability(probability_by_symptom[key], error.probability)
        else:
            probability_by_symptom[key] = error.probability

    merged: List[MergedError] = []
    for (detectors, observables), probability in probability_by_symptom.items():
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                "Merged error has probability >= 0.5, which would give a non-positive cost."
            )
        merged.append(
            MergedError(
                probability=probability,
                likelihood_cost=float(-math.log(probability / (1 - probability))),
                detectors=detectors,
                observables=observables,
            )
        )
    return merged
+ ) + merged.append( + MergedError( + probability=probability, + likelihood_cost=float(-math.log(probability / (1 - probability))), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +def build_decoder_data( + dem: stim.DetectorErrorModel, + *, + merge_errors_in_dem: bool = True, +) -> DecoderData: + errors = merged_errors(dem) if merge_errors_in_dem else list(iter_dem_errors(dem)) + detector_to_errors: List[List[int]] = [[] for _ in range(dem.num_detectors)] + for ei, error in enumerate(errors): + for d in error.detectors: + detector_to_errors[d].append(ei) + return DecoderData( + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + errors=errors, + detector_to_errors=detector_to_errors, + error_costs=np.asarray([e.likelihood_cost for e in errors], dtype=np.float64), + error_detectors=[e.detectors for e in errors], + error_detector_sets=[frozenset(e.detectors) for e in errors], + error_observables=[e.observables for e in errors], + ) + + +def unpack_bit_packed_rows(bits: np.ndarray, count: int) -> np.ndarray: + return np.unpackbits(bits, bitorder="little", axis=1, count=count).astype(bool, copy=False) + + +def initial_detector_counts(data: DecoderData, active_detectors: np.ndarray) -> np.ndarray: + counts = np.zeros(len(data.errors), dtype=np.int32) + for d in np.flatnonzero(active_detectors): + for ei in data.detector_to_errors[int(d)]: + counts[ei] += 1 + return counts + + +def apply_error( + data: DecoderData, + active_detectors: np.ndarray, + active_detector_counts: np.ndarray, + error_index: int, +) -> Tuple[np.ndarray, np.ndarray]: + next_detectors = active_detectors.copy() + next_counts = active_detector_counts.copy() + for d in data.error_detectors[error_index]: + if next_detectors[d]: + next_detectors[d] = False + delta = -1 + else: + next_detectors[d] = True + delta = 1 + for other_error_index in data.detector_to_errors[d]: + next_counts[other_error_index] += delta + return next_detectors, next_counts + + 
def plain_detcost_for_detector(
    data: DecoderData,
    detector: int,
    blocked_errors: np.ndarray,
    active_detector_counts: np.ndarray,
) -> float:
    """Cheapest fractional price for one active detector.

    Every unblocked error covering `detector` offers its likelihood cost split
    evenly over the active detectors it touches (cost / count). The minimum
    such share is a lower bound on what this detector must pay in any
    completion. Returns infinity when all covering errors are blocked.
    """
    cheapest = math.inf
    for error_index in data.detector_to_errors[detector]:
        if blocked_errors[error_index]:
            continue
        touched = int(active_detector_counts[error_index])
        # An error incident to an active detector must itself touch >= 1 active detector.
        assert touched > 0
        share = float(data.error_costs[error_index]) / touched
        cheapest = min(cheapest, share)
    return cheapest


def plain_detcost_heuristic(
    data: DecoderData,
    active_detectors: np.ndarray,
    blocked_errors: np.ndarray,
    active_detector_counts: np.ndarray,
) -> float:
    """Admissible A* heuristic: sum of per-detector cheapest shares.

    Returns infinity as soon as some active detector has no unblocked
    covering error (the residual syndrome is infeasible under pruning).
    """
    bound = 0.0
    for detector in map(int, np.flatnonzero(active_detectors)):
        price = plain_detcost_for_detector(
            data=data,
            detector=detector,
            blocked_errors=blocked_errors,
            active_detector_counts=active_detector_counts,
        )
        if price == math.inf:
            return math.inf
        bound += price
    return bound


def compute_minimal_resolution_combos(
    available_pattern_masks: Iterable[int],
    subset_size: int,
) -> Dict[int, Tuple[Tuple[int, ...], ...]]:
    """Enumerate set-minimal pattern combinations per XOR target.

    For each non-zero target bitmask over a subset of `subset_size` detectors,
    collect combinations of distinct available pattern masks whose XOR equals
    the target, keeping only combinations that are minimal under set
    inclusion. Each value is sorted by (length, lexicographic order); targets
    with no resolving combination are omitted from the result.
    """
    patterns = tuple(sorted(set(available_pattern_masks)))
    minimal: Dict[int, List[Tuple[int, ...]]] = {
        target: [] for target in range(1, 1 << subset_size)
    }
    max_combo_len = min(len(patterns), subset_size)
    for combo_len in range(1, max_combo_len + 1):
        for combo in itertools.combinations(patterns, combo_len):
            target = 0
            for mask in combo:
                target ^= mask
            if target == 0:
                # The combination cancels itself out; it resolves nothing.
                continue
            new_set = set(combo)
            dominated = False
            survivors: List[Tuple[int, ...]] = []
            for old_combo in minimal[target]:
                old_set = set(old_combo)
                if new_set >= old_set:
                    # An already-kept combo is contained in the new one:
                    # the new combo is redundant, but the old stays.
                    dominated = True
                    survivors.append(old_combo)
                elif old_set >= new_set:
                    # The new combo strictly refines the old one; evict it.
                    continue
                else:
                    survivors.append(old_combo)
            if not dominated:
                survivors.append(combo)
            survivors.sort(key=lambda c: (len(c), c))
            minimal[target] = survivors
    return {
        target: tuple(combos)
        for target, combos in minimal.items()
        if combos
    }

+ +def build_subset_library(data: DecoderData, max_subset_size: int) -> SubsetLibrary: + library_keys: set[Tuple[int, ...]] = set() + if max_subset_size >= 1: + for detector in range(data.num_detectors): + library_keys.add((detector,)) + + for detectors in data.error_detectors: + limit = min(max_subset_size, len(detectors)) + for subset_size in range(1, limit + 1): + for subset_detectors in itertools.combinations(detectors, subset_size): + library_keys.add(tuple(subset_detectors)) + + subsets_by_detector: List[List[int]] = [[] for _ in range(data.num_detectors)] + entries: List[SubsetLibraryEntry] = [] + num_subsets_by_size: Dict[int, int] = defaultdict(int) + + for subset_id, subset_detectors in enumerate(sorted(library_keys, key=lambda t: (len(t), t))): + pattern_to_errors: Dict[int, List[int]] = defaultdict(list) + for error_index, detector_set in enumerate(data.error_detector_sets): + pattern_mask = 0 + for bit_index, detector in enumerate(subset_detectors): + if detector in detector_set: + pattern_mask |= 1 << bit_index + if pattern_mask != 0: + pattern_to_errors[pattern_mask].append(error_index) + frozen_pattern_to_errors = { + pattern_mask: tuple(error_indices) + for pattern_mask, error_indices in pattern_to_errors.items() + } + entry = SubsetLibraryEntry( + subset_id=subset_id, + detectors=subset_detectors, + pattern_to_errors=frozen_pattern_to_errors, + resolution_combos=compute_minimal_resolution_combos( + available_pattern_masks=frozen_pattern_to_errors.keys(), + subset_size=len(subset_detectors), + ), + ) + entries.append(entry) + num_subsets_by_size[len(subset_detectors)] += 1 + for detector in subset_detectors: + subsets_by_detector[detector].append(subset_id) + + return SubsetLibrary( + max_subset_size=max_subset_size, + entries=entries, + subsets_by_detector=subsets_by_detector, + num_subsets_by_size=dict(sorted(num_subsets_by_size.items())), + ) + + +def detectors_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + 
detectors = np.zeros(data.num_detectors, dtype=bool) + for error_index in activated_errors: + for detector in data.error_detectors[error_index]: + detectors[detector] ^= True + return detectors + + +def observables_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + observables = np.zeros(data.num_observables, dtype=bool) + for error_index in activated_errors: + for observable in data.error_observables[error_index]: + observables[observable] ^= True + return observables + + +def decode( + data: DecoderData, + detections: np.ndarray, + *, + det_beam: float = INF, + opt_subset_solver: Optional[SubsetLPHeuristic] = None, + verbose_search: bool = False, +) -> DecodeResult: + start_time = time.perf_counter() + if opt_subset_solver is not None: + opt_subset_solver.reset_stats() + + heuristic_calls = 0 + plain_heuristic_calls = 0 + projection_heuristic_calls = 0 + exact_refinement_calls = 0 + lp_reinserts = 0 + projected_nodes_generated = 0 + projected_nodes_refined = 0 + total_lp_refinement_gain = 0.0 + max_lp_refinement_gain = 0.0 + + initial_active_detectors = np.asarray(detections, dtype=bool).copy() + initial_counts = initial_detector_counts(data, initial_active_detectors) + initial_blocked = np.zeros(len(data.errors), dtype=bool) + heuristic_calls += 1 + plain_heuristic_calls += 1 + initial_heuristic = plain_detcost_heuristic( + data=data, + active_detectors=initial_active_detectors, + blocked_errors=initial_blocked, + active_detector_counts=initial_counts, + ) + if initial_heuristic == INF: + raise RuntimeError("Initial residual syndrome is infeasible under the current pruning rule.") + + initial_state = SearchState( + activated_errors=(), + blocked_errors=initial_blocked, + active_detectors=initial_active_detectors, + active_detector_counts=initial_counts, + path_cost=0.0, + heuristic_cost=initial_heuristic, + heuristic_source="plain", + exact_refined=(opt_subset_solver is None), + lp_solution=None, + ) + + priority_queue: 
List[Tuple[float, int, int, SearchState]] = [] + push_counter = 0 + initial_num_dets = int(initial_active_detectors.sum()) + heapq.heappush( + priority_queue, + (initial_state.path_cost + initial_state.heuristic_cost, initial_num_dets, push_counter, initial_state), + ) + push_counter += 1 + + num_pq_pushed = 1 + num_nodes_popped = 0 + max_queue_size = 1 + min_num_dets = initial_num_dets + max_num_dets = INF if det_beam == INF else min_num_dets + det_beam + + heuristic_name = ( + f"opt_subset_detcost_size_{opt_subset_solver.subset_library.max_subset_size}_lazy_projection" + if opt_subset_solver is not None + else "plain_detcost" + ) + + while priority_queue: + max_queue_size = max(max_queue_size, len(priority_queue)) + f_cost, num_dets, _, state = heapq.heappop(priority_queue) + num_nodes_popped += 1 + + if num_dets > max_num_dets: + continue + + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = INF if det_beam == INF else min_num_dets + det_beam + + if verbose_search: + print( + f"nodes_popped={num_nodes_popped} len(pq)={len(priority_queue)} " + f"lp_calls={0 if opt_subset_solver is None else opt_subset_solver.stats.lp_calls} " + f"lp_reinserts={lp_reinserts} proj_generated={projected_nodes_generated} " + f"proj_refined={projected_nodes_refined} " + f"proj_unrefined_so_far={projected_nodes_generated - projected_nodes_refined} " + f"active_dets={num_dets} beam_max={max_num_dets} depth={len(state.activated_errors)} " + f"f={f_cost:.12g} g={state.path_cost:.12g} h={state.heuristic_cost:.12g} " + f"h_source={state.heuristic_source} exact_refined={state.exact_refined}" + ) + + if num_dets == 0: + elapsed_seconds = time.perf_counter() - start_time + lp_calls = 0 if opt_subset_solver is None else opt_subset_solver.stats.lp_calls + lp_total_seconds = 0.0 if opt_subset_solver is None else opt_subset_solver.stats.lp_total_seconds + return DecodeResult( + activated_errors=state.activated_errors, + path_cost=state.path_cost, + stats=DecodeStats( + 
num_pq_pushed=num_pq_pushed, + num_nodes_popped=num_nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=heuristic_calls, + plain_heuristic_calls=plain_heuristic_calls, + projection_heuristic_calls=projection_heuristic_calls, + exact_refinement_calls=exact_refinement_calls, + lp_calls=lp_calls, + lp_reinserts=lp_reinserts, + projected_nodes_generated=projected_nodes_generated, + projected_nodes_refined=projected_nodes_refined, + projected_nodes_unrefined_at_finish=projected_nodes_generated - projected_nodes_refined, + total_lp_refinement_gain=total_lp_refinement_gain, + max_lp_refinement_gain=max_lp_refinement_gain, + lp_total_seconds=lp_total_seconds, + elapsed_seconds=elapsed_seconds, + heuristic_name=heuristic_name, + ), + ) + + if opt_subset_solver is not None and not state.exact_refined: + heuristic_calls += 1 + exact_refinement_calls += 1 + previous_h = state.heuristic_cost + previous_source = state.heuristic_source + exact_solution, exact_payload = opt_subset_solver.solve_exact( + active_detectors=state.active_detectors, + blocked_errors=state.blocked_errors, + ) + exact_h = exact_solution.value + reinserted = False + discarded = False + + if exact_h == INF: + discarded = True + if previous_source == "projected": + projected_nodes_refined += 1 + else: + if exact_h + 1e-7 < previous_h: + raise AssertionError( + f"Exact subset LP lower bound {exact_h} is below stored {previous_source} lower bound {previous_h}." 
+ ) + delta = exact_h - previous_h + total_lp_refinement_gain += delta + max_lp_refinement_gain = max(max_lp_refinement_gain, delta) + state.heuristic_cost = exact_h + state.heuristic_source = "exact" + state.exact_refined = True + state.lp_solution = exact_solution + if previous_source == "projected": + projected_nodes_refined += 1 + if delta > HEURISTIC_EPS: + reinserted = True + lp_reinserts += 1 + heapq.heappush( + priority_queue, + (state.path_cost + state.heuristic_cost, num_dets, push_counter, state), + ) + push_counter += 1 + + if opt_subset_solver.logger is not None: + payload = dict(exact_payload) + payload.update( + { + "call_index": exact_refinement_calls, + "phase": "exact_refinement", + "depth": len(state.activated_errors), + "nodes_popped": num_nodes_popped, + "path_cost": state.path_cost, + "active_detector_count": num_dets, + "approx_h": previous_h, + "exact_h": exact_h, + "delta": INF if exact_h == INF else exact_h - previous_h, + "heuristic_source_before": previous_source, + "reinserted": reinserted, + "discarded": discarded, + } + ) + opt_subset_solver.logger.maybe_log(call_index=exact_refinement_calls, payload=payload) + + if verbose_search: + delta_text = "INF" if exact_h == INF else f"{exact_h - previous_h:.12g}" + exact_text = "INF" if exact_h == INF else f"{exact_h:.12g}" + print( + f" lp_refine approx_h={previous_h:.12g} exact_h={exact_text} delta={delta_text} " + f"vars={exact_solution.num_variables} constraints={exact_solution.num_constraints} " + f"active_subsets={exact_solution.num_active_subsets} comps={exact_solution.num_components} " + f"reinserted={reinserted} discarded={discarded}" + ) + + if discarded or reinserted: + continue + + min_detector = int(np.flatnonzero(state.active_detectors)[0]) + blocked_prefix = state.blocked_errors.copy() + children_generated = 0 + children_projected = 0 + children_beam_pruned = 0 + children_infeasible = 0 + + for error_index in data.detector_to_errors[min_detector]: + blocked_prefix[error_index] 
= True + if state.blocked_errors[error_index]: + continue + + child_active_detectors, child_active_counts = apply_error( + data=data, + active_detectors=state.active_detectors, + active_detector_counts=state.active_detector_counts, + error_index=error_index, + ) + child_num_dets = int(child_active_detectors.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_blocked = blocked_prefix.copy() + child_path_cost = state.path_cost + float(data.error_costs[error_index]) + + if opt_subset_solver is None: + heuristic_calls += 1 + plain_heuristic_calls += 1 + child_heuristic = plain_detcost_heuristic( + data=data, + active_detectors=child_active_detectors, + blocked_errors=child_blocked, + active_detector_counts=child_active_counts, + ) + child_source = "plain" + child_exact_refined = True + child_lp_solution = None + else: + if state.lp_solution is None: + raise AssertionError("Subset-LP projection requires an exact-refined parent solution.") + heuristic_calls += 1 + projection_heuristic_calls += 1 + projected_nodes_generated += 1 + children_projected += 1 + child_heuristic = opt_subset_solver.project_from_parent( + parent_solution=state.lp_solution, + child_active_detectors=child_active_detectors, + child_blocked_errors=child_blocked, + ) + child_source = "projected" + child_exact_refined = False + child_lp_solution = None + + if child_heuristic == INF: + children_infeasible += 1 + continue + + child_state = SearchState( + activated_errors=state.activated_errors + (error_index,), + blocked_errors=child_blocked, + active_detectors=child_active_detectors, + active_detector_counts=child_active_counts, + path_cost=child_path_cost, + heuristic_cost=child_heuristic, + heuristic_source=child_source, + exact_refined=child_exact_refined, + lp_solution=child_lp_solution, + ) + heapq.heappush( + priority_queue, + (child_path_cost + child_heuristic, child_num_dets, push_counter, child_state), + ) + push_counter += 1 + num_pq_pushed += 1 + 
children_generated += 1 + + if verbose_search: + print( + f" expanded children_generated={children_generated} children_projected={children_projected} " + f"beam_pruned={children_beam_pruned} infeasible={children_infeasible}" + ) + + raise RuntimeError("Decoding failed to find any completion.") + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim detector error models. " + "Supports plain detcost and lazy subset-based LP lower bounds." + ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a Stim circuit file.") + parser.add_argument( + "--shot", + type=int, + default=0, + help="Zero-based sampled shot index to decode.", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample before selecting --shot.", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Seed passed to stim.compile_detector_sampler(...).sample(...).", + ) + parser.add_argument( + "--det-beam", + type=parse_beam, + default=INF, + help="Beam cutoff on the residual detector count. Use an integer or 'inf'.", + ) + parser.add_argument( + "--opt-subset-detcost-size", + type=int, + default=0, + help=( + "Use the lazy subset-based LP heuristic with library subsets of size at most N. " + "Use 0 for plain detcost, 1 for the optimal singleton LP, etc." 
+ ), + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the sampled shot's active detector IDs before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the activated error indices in the final decoding.", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + parser.add_argument( + "--lp-log-path", + type=Path, + default=None, + help="Optional JSONL file for logging details of each exact subset-LP refinement.", + ) + parser.add_argument( + "--lp-log-top-k", + type=int, + default=10, + help="When logging exact LP refinements, include at most this many top subsets.", + ) + parser.add_argument( + "--lp-log-every", + type=int, + default=1, + help="When logging exact LP refinements, only write every k-th refinement.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.opt_subset_detcost_size < 0: + parser.error("--opt-subset-detcost-size must be non-negative.") + if args.lp_log_every <= 0: + parser.error("--lp-log-every must be positive.") + if args.lp_log_top_k <= 0: + parser.error("--lp-log-top-k must be positive.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + data = build_decoder_data(dem, merge_errors_in_dem=args.merge_errors) + + subset_library = None + subset_solver = None + if args.opt_subset_detcost_size > 0: + subset_library = 
build_subset_library(data, args.opt_subset_detcost_size) + lp_logger = None + if args.lp_log_path is not None: + lp_logger = LPLogger( + args.lp_log_path, + every=args.lp_log_every, + top_k=args.lp_log_top_k, + ) + subset_solver = SubsetLPHeuristic(data, subset_library, logger=lp_logger) + + dets_packed, obs_packed = circuit.compile_detector_sampler(seed=args.seed).sample( + shots=args.sample_num_shots, + separate_observables=True, + bit_packed=True, + ) + detections = unpack_bit_packed_rows(dets_packed, count=dem.num_detectors) + observables = unpack_bit_packed_rows(obs_packed, count=dem.num_observables) + + if args.shot >= detections.shape[0]: + parser.error(f"--shot={args.shot} is out of range for {detections.shape[0]} sampled shots.") + + shot_detections = detections[args.shot] + shot_observables = observables[args.shot] if observables.size else np.zeros(0, dtype=bool) + + print(f"circuit = {args.circuit}") + print( + "heuristic = " + + ( + "plain_detcost" + if subset_solver is None + else f"opt_subset_detcost_size_{subset_library.max_subset_size}_lazy_projection" + ) + ) + print(f"shot = {args.shot}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"num_detectors = {data.num_detectors}") + print(f"num_observables = {data.num_observables}") + print(f"num_errors = {len(data.errors)}") + print(f"beam = {args.det_beam}") + if subset_library is not None: + print(f"subset_library_size = {len(subset_library.entries)}") + print( + "subset_library_by_size = " + + ", ".join( + f"{size}:{count}" for size, count in subset_library.num_subsets_by_size.items() + ) + ) + if args.show_shot_detectors: + print(f"shot_detectors = {format_indices(np.flatnonzero(shot_detections), 'D')}") + + result = decode( + data=data, + detections=shot_detections, + det_beam=args.det_beam, + opt_subset_solver=subset_solver, + verbose_search=args.verbose_search, + ) + + predicted_observables = observables_from_solution(data, result.activated_errors) + reproduced_detectors = 
detectors_from_solution(data, result.activated_errors) + if not np.array_equal(reproduced_detectors, shot_detections): + raise AssertionError("Decoded error set does not reproduce the shot's syndrome.") + + print(f"solution_size = {len(result.activated_errors)}") + print(f"solution_cost = {result.path_cost:.12g}") + if args.show_error_indices: + print(f"activated_errors = {format_indices(result.activated_errors, 'E')}") + print(f"predicted_observables = {format_indices(np.flatnonzero(predicted_observables), 'L')}") + print(f"sample_observables = {format_indices(np.flatnonzero(shot_observables), 'L')}") + print(f"observables_match = {bool(np.array_equal(predicted_observables, shot_observables))}") + print(f"num_pq_pushed = {result.stats.num_pq_pushed}") + print(f"num_nodes_popped = {result.stats.num_nodes_popped}") + print(f"max_queue_size = {result.stats.max_queue_size}") + print(f"heuristic_calls = {result.stats.heuristic_calls}") + print(f"plain_heuristic_calls = {result.stats.plain_heuristic_calls}") + print(f"projection_heuristic_calls = {result.stats.projection_heuristic_calls}") + print(f"exact_refinement_calls = {result.stats.exact_refinement_calls}") + print(f"lp_calls = {result.stats.lp_calls}") + print(f"lp_reinserts = {result.stats.lp_reinserts}") + print(f"projected_nodes_generated = {result.stats.projected_nodes_generated}") + print(f"projected_nodes_refined = {result.stats.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish = {result.stats.projected_nodes_unrefined_at_finish}") + print(f"total_lp_refinement_gain = {result.stats.total_lp_refinement_gain:.12g}") + print(f"max_lp_refinement_gain = {result.stats.max_lp_refinement_gain:.12g}") + print(f"lp_total_seconds = {result.stats.lp_total_seconds:.6f}") + print(f"elapsed_seconds = {result.stats.elapsed_seconds:.6f}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_qec_inactive_lp.py b/src/py/astar/astar_qec_inactive_lp.py 
new file mode 100644 index 0000000..d9e2c36 --- /dev/null +++ b/src/py/astar/astar_qec_inactive_lp.py @@ -0,0 +1,1399 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics. + +This version keeps the same Stim-facing API as the earlier greedy prototype but +adds lazy reinsertion / parent-y projection, in the same spirit as the lazy +optimal-singleton prototype: + + * nodes are seeded with a cheap feasible lower bound; + * when a node is popped, the selected heuristic is evaluated on that node; + * if the refined heuristic raises the node key, the node is reinserted; + * expanded nodes project their current feasible y-prices onto children; + * optionally, the projected child bound is maxed with plain detcost. + +Supported heuristic choices: + plain original detector-wise feasible point + asc_deg zero-start saturation ordered by ascending detector degree + desc_plain zero-start saturation ordered by descending plain y_d + plain_sweep start from plain, then one descending saturation sweep + best_of_two max(plain_sweep, asc_deg) + best_of_three max(plain_sweep, asc_deg, desc_plain) + exact_lp exact optimal singleton LP lower bound + exact_lp_plus_inactive + exact LP lower bound with extra inactive-detector no-one-hot constraints + +When --lazy-reinsert-heuristics is enabled (the default), the root is seeded by +plain detcost and only popped nodes are refined with the selected heuristic. +This works directly for the support-only heuristics because each returns a +feasible singleton-budget vector y, and projecting that y to a child by +keeping prices on detectors that remain active and zeroing newly active +detectors is still a feasible child singleton-budget point. For +exact_lp_plus_inactive, the refined LP optimum is not directly projectable to a +child, so lazy mode keeps the current projectable singleton-budget prices for +child seeding and uses the tightened LP only when refining popped nodes. 
+""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from collections import defaultdict, deque +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] + + +@dataclass +class SupportData: + active_detectors: List[int] + supports: List[Tuple[Tuple[int, ...], float]] + incident: Dict[int, List[int]] + + +@dataclass +class SearchState: + errs: np.ndarray + blocked_errs: np.ndarray + dets: np.ndarray + det_counts: np.ndarray + g_cost: float + h_cost: float + h_source: str + refined: bool + y_prices: Optional[np.ndarray] + + +@dataclass +class DecodeResult: + success: bool + errs: np.ndarray + residual_dets: np.ndarray + cost: float + nodes_pushed: int + nodes_popped: int + max_queue_size: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + refinement_calls: int + lp_calls: int + reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_refinement_gain: float + max_refinement_gain: float + elapsed_seconds: float + + +class UnionFind: + def __init__(self, n: int) -> None: + self.parent = list(range(n)) + self.rank = [0] * n + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 
+ + +def xor_probability(p0: float, p1: float) -> float: + return p0 * (1.0 - p1) + (1.0 - p0) * p1 + + +def iter_dem_errors_from_dem(dem: stim.DetectorErrorModel) -> Iterable[ErrorRecord]: + for instruction in dem.flattened(): + if instruction.type != "error": + continue + probability = float(instruction.args_copy()[0]) + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + f"Expected flattened error probabilities in (0, 0.5), got {probability}." + ) + + detectors: set[int] = set() + observables: set[int] = set() + for target in instruction.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected DEM target: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + + yield ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=tuple(sorted(detectors)), + observables=tuple(sorted(observables)), + ) + + +def merged_errors_from_dem(dem: stim.DetectorErrorModel) -> List[ErrorRecord]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + for error in iter_dem_errors_from_dem(dem): + key = (error.detectors, error.observables) + p_old = errors_by_symptom.get(key) + if p_old is None: + p_new = error.probability + else: + p_new = xor_probability(p_old, error.probability) + errors_by_symptom[key] = p_new + + merged: List[ErrorRecord] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + f"Merged error has probability >= 0.5 ({probability}); cannot assign positive cost." 
+ ) + merged.append( + ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +class GreedySingletonHeuristicDecoder: + def __init__( + self, + errors: Sequence[ErrorRecord], + num_detectors: int, + num_observables: int, + *, + heuristic: str = "best_of_two", + respect_blocked_errors_in_heuristic: bool = True, + lazy_reinsert_heuristics: bool = True, + projection_combine_max_plain: bool = True, + verbose_search: bool = False, + ) -> None: + self.errors = list(errors) + self.num_errors = len(self.errors) + self.num_detectors = int(num_detectors) + self.num_observables = int(num_observables) + self.heuristic_name = heuristic + self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic + self.lazy_reinsert_heuristics = lazy_reinsert_heuristics + self.projection_combine_max_plain = projection_combine_max_plain + self.verbose_search = verbose_search + + self.probabilities = np.array([err.probability for err in self.errors], dtype=np.float64) + self.weights = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64) + self.error_detectors: List[Tuple[int, ...]] = [tuple(err.detectors) for err in self.errors] + self.error_observables: List[Tuple[int, ...]] = [tuple(err.observables) for err in self.errors] + + d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)] + for ei, dets in enumerate(self.error_detectors): + for d in dets: + d2e_lists[d].append(ei) + self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists] + + self.reset_stats() + + def reset_stats(self) -> None: + self.heuristic_calls = 0 + self.plain_heuristic_calls = 0 + self.projection_heuristic_calls = 0 + self.refinement_calls = 0 + self.lp_calls = 0 + self.reinserts = 0 + self.projected_nodes_generated = 0 + self.projected_nodes_refined = 0 + self.total_refinement_gain = 0.0 + self.max_refinement_gain = 0.0 + + 
@property + def mode_name(self) -> str: + if self.heuristic_name == "plain": + return "plain" + if self.lazy_reinsert_heuristics: + suffix = "-lazy-projection" + if self.projection_combine_max_plain: + suffix += "-maxplain" + return f"{self.heuristic_name}{suffix}" + return self.heuristic_name + + @staticmethod + def heuristic_has_projectable_prices(name: str) -> bool: + return name != "exact_lp_plus_inactive" + + def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + return available + + def _has_cover_for_all_active_detectors(self, dets: np.ndarray, available_errs: np.ndarray) -> bool: + for d in np.flatnonzero(dets): + found = False + for ei in self.d2e[int(d)]: + if available_errs[int(ei)]: + found = True + break + if not found: + return False + return True + + def build_support_data(self, active_dets: np.ndarray, available_errs: np.ndarray) -> SupportData: + active_list = sorted(map(int, np.flatnonzero(active_dets))) + incident: Dict[int, List[int]] = {d: [] for d in active_list} + support_to_weight: Dict[Tuple[int, ...], float] = {} + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + support = tuple(d for d in self.error_detectors[ei] if active_dets[d]) + if not support: + continue + weight = float(self.weights[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + supports = list(support_to_weight.items()) + for i, (support, _weight) in enumerate(supports): + for d in support: + if d in incident: + incident[d].append(i) + + return SupportData(active_detectors=active_list, supports=supports, incident=incident) + + def _check_coverage(self, support_data: SupportData) -> bool: + return all(len(support_data.incident[d]) > 0 for d in support_data.active_detectors) + + def plain_detcost_from_counts( + self, + dets: np.ndarray, + available_errs: np.ndarray, + 
det_counts: np.ndarray, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + self.plain_heuristic_calls += 1 + active = np.flatnonzero(dets) + if active.size == 0: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for d in active: + best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + assert count > 0 + value = self.weights[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF, None + y[int(d)] = best + total += best + return total, y + + def heuristic_plain(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + y = np.zeros(self.num_detectors, dtype=np.float64) + for d in support_data.active_detectors: + best = INF + for i in support_data.incident[d]: + support, weight = support_data.supports[i] + best = min(best, weight / len(support)) + y[d] = best + return float(y[support_data.active_detectors].sum()), y + + def heuristic_saturation_zero(self, support_data: SupportData, *, order_kind: str) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + slack = np.array([weight for _support, weight in support_data.supports], dtype=np.float64) + y = np.zeros(self.num_detectors, dtype=np.float64) + + if order_kind == "asc_deg": + order = sorted(support_data.active_detectors, key=lambda d: (len(support_data.incident[d]), d)) + elif order_kind == "desc_plain": + _plain_value, y_plain = self.heuristic_plain(support_data) + if y_plain is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y_plain[d], d), 
reverse=True) + else: + raise ValueError(f"Unknown order_kind={order_kind!r}") + + for d in order: + value = min(slack[i] for i in support_data.incident[d]) + if value < 0: + value = 0.0 + y[d] = value + for i in support_data.incident[d]: + slack[i] -= value + return float(y[support_data.active_detectors].sum()), y + + def heuristic_plain_sweep(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + plain_value, y = self.heuristic_plain(support_data) + if y is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y[d], d), reverse=True) + for d in order: + max_feasible = min( + weight - sum(y[dd] for dd in support if dd != d) + for support, weight in support_data.supports + if d in support + ) + if max_feasible > y[d]: + y[d] = max_feasible + return float(y[support_data.active_detectors].sum()), y + + def heuristic_exact_lp(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + active = support_data.active_detectors + if not active: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + detector_index = {d: i for i, d in enumerate(active)} + uf = UnionFind(len(active)) + for support, _weight in support_data.supports: + if len(support) > 1: + a = detector_index[support[0]] + for d in support[1:]: + uf.union(a, detector_index[d]) + + components: Dict[int, List[int]] = defaultdict(list) + for d in active: + components[uf.find(detector_index[d])].append(d) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for component in components.values(): + component_set = set(component) + local = {d: i for i, d in enumerate(sorted(component))} + component_supports: List[Tuple[Tuple[int, ...], float]] = [] + for support, weight in support_data.supports: + if support[0] in component_set: + component_supports.append((tuple(local[d] for d in support), weight)) + + rows: List[int] = [] + cols: List[int] = [] + data: 
List[float] = [] + rhs: List[float] = [] + for r, (support, weight) in enumerate(component_supports): + rhs.append(weight) + for c in support: + rows.append(r) + cols.append(c) + data.append(1.0) + + a_ub = csr_matrix( + (data, (rows, cols)), + shape=(len(component_supports), len(component)), + dtype=np.float64, + ) + self.lp_calls += 1 + result = linprog( + c=-np.ones(len(component), dtype=np.float64), + A_ub=a_ub, + b_ub=np.array(rhs, dtype=np.float64), + bounds=[(0.0, None)] * len(component), + method="highs", + ) + if not result.success: + return INF, None + total += -float(result.fun) + for d, value in zip(sorted(component), result.x): + y[d] = float(value) + return float(total), y + + + def _reachable_available_components( + self, + dets: np.ndarray, + available_errs: np.ndarray, + ) -> List[Tuple[List[int], List[int], List[int]]]: + active = sorted(map(int, np.flatnonzero(dets))) + if not active: + return [] + + det_visited = np.zeros(self.num_detectors, dtype=bool) + err_visited = np.zeros(self.num_errors, dtype=bool) + components: List[Tuple[List[int], List[int], List[int]]] = [] + + for seed in active: + if det_visited[seed]: + continue + det_visited[seed] = True + queue: deque[int] = deque([seed]) + component_dets: List[int] = [] + component_errs: List[int] = [] + while queue: + d = queue.popleft() + component_dets.append(d) + for ei in self.d2e[d]: + ei = int(ei) + if not available_errs[ei] or err_visited[ei]: + continue + err_visited[ei] = True + component_errs.append(ei) + for dd in self.error_detectors[ei]: + dd = int(dd) + if not det_visited[dd]: + det_visited[dd] = True + queue.append(dd) + + component_active = [d for d in component_dets if dets[d]] + if not component_active: + continue + component_inactive = [d for d in component_dets if not dets[d]] + components.append((component_active, component_inactive, component_errs)) + + return components + + def _solve_component_exact_lp_plus_inactive( + self, + component_active: Sequence[int], + 
component_inactive: Sequence[int], + component_errors: Sequence[int], + ) -> float: + if not component_active: + return 0.0 + if not component_errors: + return INF + + local_errors = list(component_errors) + det_to_local_errors: Dict[int, List[int]] = defaultdict(list) + for local_ei, ei in enumerate(local_errors): + for d in self.error_detectors[ei]: + det_to_local_errors[int(d)].append(local_ei) + + active_set = set(component_active) + inactive_set = set(component_inactive) + deg: Dict[int, int] = { + d: len(det_to_local_errors.get(d, [])) + for d in active_set | inactive_set + } + alive = np.ones(len(local_errors), dtype=bool) + queue: deque[int] = deque(d for d in component_inactive if deg.get(d, 0) == 1) + + while queue: + d = queue.popleft() + if deg.get(d, 0) != 1: + continue + forced_local = next( + (local_ei for local_ei in det_to_local_errors.get(d, []) if alive[local_ei]), + None, + ) + if forced_local is None: + deg[d] = 0 + continue + if not alive[forced_local]: + continue + alive[forced_local] = False + for dd in self.error_detectors[local_errors[forced_local]]: + dd = int(dd) + if dd not in deg or deg[dd] <= 0: + continue + deg[dd] -= 1 + if dd in active_set and deg[dd] == 0: + return INF + if dd in inactive_set and deg[dd] == 1: + queue.append(dd) + + for d in component_active: + if deg.get(d, 0) <= 0: + return INF + + reduced_errors = [ei for local_ei, ei in enumerate(local_errors) if alive[local_ei]] + if not reduced_errors: + return INF + + local_error_index = {ei: local_ei for local_ei, ei in enumerate(reduced_errors)} + det_to_reduced_errors: Dict[int, List[int]] = defaultdict(list) + for ei in reduced_errors: + local_ei = local_error_index[ei] + for d in self.error_detectors[ei]: + d = int(d) + if deg.get(d, 0) > 0: + det_to_reduced_errors[d].append(local_ei) + + inactive_with_incidence = [d for d in component_inactive if deg.get(d, 0) > 0] + num_x = len(reduced_errors) + num_s = len(inactive_with_incidence) + num_vars = num_x + num_s + c = 
np.zeros(num_vars, dtype=np.float64) + c[:num_x] = self.weights[reduced_errors] + + inactive_col = {d: num_x + i for i, d in enumerate(inactive_with_incidence)} + + ub_rows: List[int] = [] + ub_cols: List[int] = [] + ub_data: List[float] = [] + b_ub: List[float] = [] + ub_r = 0 + + for d in component_active: + incident = det_to_reduced_errors.get(d, []) + if not incident: + return INF + for local_ei in incident: + ub_rows.append(ub_r) + ub_cols.append(local_ei) + ub_data.append(-1.0) + b_ub.append(-1.0) + ub_r += 1 + + for d in inactive_with_incidence: + s_col = inactive_col[d] + for local_ei in det_to_reduced_errors[d]: + ub_rows.extend([ub_r, ub_r]) + ub_cols.extend([local_ei, s_col]) + ub_data.extend([2.0, -1.0]) + b_ub.append(0.0) + ub_r += 1 + + eq_rows: List[int] = [] + eq_cols: List[int] = [] + eq_data: List[float] = [] + b_eq: List[float] = [] + eq_r = 0 + + for d in inactive_with_incidence: + s_col = inactive_col[d] + eq_rows.append(eq_r) + eq_cols.append(s_col) + eq_data.append(1.0) + for local_ei in det_to_reduced_errors[d]: + eq_rows.append(eq_r) + eq_cols.append(local_ei) + eq_data.append(-1.0) + b_eq.append(0.0) + eq_r += 1 + + a_ub = None + if ub_r > 0: + a_ub = csr_matrix( + (ub_data, (ub_rows, ub_cols)), + shape=(ub_r, num_vars), + dtype=np.float64, + ) + + a_eq = None + if eq_r > 0: + a_eq = csr_matrix( + (eq_data, (eq_rows, eq_cols)), + shape=(eq_r, num_vars), + dtype=np.float64, + ) + + self.lp_calls += 1 + result = linprog( + c=c, + A_ub=a_ub, + b_ub=np.array(b_ub, dtype=np.float64) if b_ub else None, + A_eq=a_eq, + b_eq=np.array(b_eq, dtype=np.float64) if b_eq else None, + bounds=[(0.0, None)] * num_vars, + method="highs", + ) + if not result.success or result.fun is None: + return INF + return float(result.fun) + + def heuristic_exact_lp_plus_inactive( + self, + dets: np.ndarray, + available_errs: np.ndarray, + ) -> Tuple[float, Optional[np.ndarray]]: + if not np.any(dets): + return 0.0, None + if not 
self._has_cover_for_all_active_detectors(dets, available_errs): + return INF, None + + total = 0.0 + for component_active, component_inactive, component_errors in self._reachable_available_components( + dets, + available_errs, + ): + component_value = self._solve_component_exact_lp_plus_inactive( + component_active, + component_inactive, + component_errors, + ) + if math.isinf(component_value): + return INF, None + total += component_value + return float(total), None + + def evaluate_named_heuristic(self, support_data: SupportData, name: str) -> Tuple[float, Optional[np.ndarray]]: + if name == "plain": + return self.heuristic_plain(support_data) + if name == "asc_deg": + return self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if name == "desc_plain": + return self.heuristic_saturation_zero(support_data, order_kind="desc_plain") + if name == "plain_sweep": + return self.heuristic_plain_sweep(support_data) + if name == "best_of_two": + v1, y1 = self.heuristic_plain_sweep(support_data) + v2, y2 = self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if v1 >= v2: + return v1, y1 + return v2, y2 + if name == "best_of_three": + candidates = [ + self.heuristic_plain_sweep(support_data), + self.heuristic_saturation_zero(support_data, order_kind="asc_deg"), + self.heuristic_saturation_zero(support_data, order_kind="desc_plain"), + ] + return max(candidates, key=lambda t: t[0]) + if name == "exact_lp": + return self.heuristic_exact_lp(support_data) + raise ValueError(f"Unknown heuristic {name!r}") + + def compute_support_based_heuristic( + self, + dets: np.ndarray, + errs: np.ndarray, + blocked_errs: np.ndarray, + *, + name: Optional[str] = None, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + available = self._available_errors(errs, blocked_errs) + heuristic_name = name or self.heuristic_name + if heuristic_name == "exact_lp_plus_inactive": + return self.heuristic_exact_lp_plus_inactive(dets, available) + 
support_data = self.build_support_data(dets, available) + return self.evaluate_named_heuristic(support_data, heuristic_name) + + def project_child_y( + self, + parent_state: SearchState, + child_dets: np.ndarray, + child_errs: np.ndarray, + child_blocked_errs: np.ndarray, + child_det_counts: np.ndarray, + flipped_detectors: Sequence[int], + ) -> Tuple[float, Optional[np.ndarray], str]: + if parent_state.y_prices is None: + raise AssertionError("Expected a stored feasible y vector before projecting to a child.") + + self.heuristic_calls += 1 + self.projection_heuristic_calls += 1 + available = self._available_errors(child_errs, child_blocked_errs) + if not self._has_cover_for_all_active_detectors(child_dets, available): + return INF, None, "projected" + + y_projected = np.zeros(self.num_detectors, dtype=np.float64) + keep = parent_state.dets & child_dets + y_projected[keep] = parent_state.y_prices[keep] + projected_value = float(y_projected[np.flatnonzero(child_dets)].sum()) + best_value = projected_value + best_y = y_projected + best_source = "projected" + + if self.projection_combine_max_plain: + plain_value, plain_y = self.plain_detcost_from_counts(child_dets, available, child_det_counts) + if plain_y is None: + return INF, None, "plain" + if plain_value > best_value + HEURISTIC_EPS: + best_value = plain_value + best_y = plain_y + best_source = "plain" + + return best_value, best_y, best_source + + def report_root_heuristics(self, dets: np.ndarray, errs: np.ndarray, blocked_errs: np.ndarray) -> List[Tuple[str, float]]: + available = self._available_errors(errs, blocked_errs) + support_data = self.build_support_data(dets, available) + names = [ + "plain", + "asc_deg", + "desc_plain", + "plain_sweep", + "best_of_two", + "best_of_three", + "exact_lp", + "exact_lp_plus_inactive", + ] + out: List[Tuple[str, float]] = [] + saved_lp_calls = self.lp_calls + for name in names: + if name == "exact_lp_plus_inactive": + value, _ = self.heuristic_exact_lp_plus_inactive(dets, 
available) + else: + value, _ = self.evaluate_named_heuristic(support_data, name) + out.append((name, value)) + self.lp_calls = saved_lp_calls + return out + + def _maybe_refine_node(self, state: SearchState) -> Tuple[SearchState, bool]: + if state.refined or self.heuristic_name == "plain" or not self.lazy_reinsert_heuristics: + return state, False + + previous_source = state.h_source + projectable = self.heuristic_has_projectable_prices(self.heuristic_name) + self.refinement_calls += 1 + new_value, new_y = self.compute_support_based_heuristic( + state.dets, + state.errs, + state.blocked_errs, + name=self.heuristic_name, + ) + if math.isinf(new_value): + if previous_source == "projected": + self.projected_nodes_refined += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost:.6f} new_h=INF delta=INF reinserted=False discarded=True" + ) + state.h_cost = INF + state.h_source = "refined" + if projectable: + state.y_prices = None + state.refined = True + return state, True + if projectable and new_y is None: + raise AssertionError(f"Expected projectable y-prices from heuristic {self.heuristic_name!r}.") + + delta = new_value - state.h_cost + self.total_refinement_gain += max(0.0, delta) + self.max_refinement_gain = max(self.max_refinement_gain, max(0.0, delta)) + + if self.heuristic_name in {"exact_lp", "exact_lp_plus_inactive"} and new_value + 1e-7 < state.h_cost: + raise AssertionError( + f"Exact LP refinement {new_value} below stored projected value {state.h_cost}." 
+ ) + + if new_value > state.h_cost + HEURISTIC_EPS: + if previous_source == "projected": + self.projected_nodes_refined += 1 + state.h_cost = new_value + state.h_source = "refined" + if projectable: + state.y_prices = new_y + state.refined = True + self.reinserts += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost - delta:.6f} new_h={new_value:.6f} delta={delta:.6f} reinserted=True discarded=False" + ) + return state, True + + # Non-improving recomputation: keep the existing projectable feasible point unless the + # selected heuristic returned a fresh one that can still be projected to children. + if previous_source == "projected": + self.projected_nodes_refined += 1 + if projectable and abs(new_value - state.h_cost) <= HEURISTIC_EPS and new_y is not None: + state.y_prices = new_y + state.refined = True + if self.verbose_search: + new_text = "INF" if math.isinf(new_value) else f"{new_value:.6f}" + print( + f" refine approx_h={state.h_cost:.6f} new_h={new_text} delta={delta:.6f} reinserted=False discarded=False" + ) + return state, False + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + root_h, root_y = self.plain_detcost_from_counts(dets0, self._available_errors(errs0, blocked0), det_counts0) + if root_y is None or math.isinf(root_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + 
refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + root_refined = (self.heuristic_name == "plain") or (not self.lazy_reinsert_heuristics) + if root_refined and self.heuristic_name != "plain": + # Eager mode: use the selected heuristic immediately. + eager_h, eager_y = self.compute_support_based_heuristic(dets0, errs0, blocked0, name=self.heuristic_name) + if math.isinf(eager_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + if self.heuristic_has_projectable_prices(self.heuristic_name): + if eager_y is None: + raise AssertionError(f"Expected projectable y-prices from heuristic {self.heuristic_name!r}.") + root_y = eager_y + root_h = eager_h + + root_state = SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + h_cost=root_h, + h_source="plain" if not root_refined else ("plain" if 
self.heuristic_name == "plain" else "refined"), + refined=root_refined, + y_prices=root_y, + ) + + heap: List[Tuple[float, int, int, SearchState]] = [] + counter = 0 + heapq.heappush(heap, (root_state.g_cost + root_state.h_cost, int(dets0.sum()), counter, root_state)) + counter += 1 + nodes_pushed = 1 + nodes_popped = 0 + max_queue_size = 1 + min_num_dets = int(dets0.sum()) + + while heap: + max_queue_size = max(max_queue_size, len(heap)) + f_cost, num_dets, _entry_id, state = heapq.heappop(heap) + nodes_popped += 1 + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f"len(heap)={len(heap)} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"lp_calls={self.lp_calls} reinserts={self.reinserts} proj_generated={self.projected_nodes_generated} " + f"proj_refined={self.projected_nodes_refined} proj_unrefined_so_far={projected_unrefined} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f} " + f"h={state.h_cost:.6f} h_source={state.h_source} refined={state.refined}" + ) + + if num_dets == 0: + return DecodeResult( + success=True, + errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + 
total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + state, should_reinsert = self._maybe_refine_node(state) + if should_reinsert: + if state.y_prices is None or math.isinf(state.h_cost): + if state.h_source == "projected": + self.projected_nodes_refined += 1 + continue + if state.h_source != "plain": + heapq.heappush(heap, (state.g_cost + state.h_cost, num_dets, counter, state)) + counter += 1 + continue + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked = state.blocked_errs.copy() + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + children_projected = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked[ei] = True + if state.errs[ei] or state.blocked_errs[ei]: + continue + + child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked = prefix_blocked.copy() + child_dets = state.dets.copy() + child_det_counts = state.det_counts.copy() + for d in self.error_detectors[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_g = state.g_cost + float(self.weights[ei]) + if self.heuristic_name == "plain" or (not self.lazy_reinsert_heuristics): + child_h, child_y = self.compute_support_based_heuristic( + child_dets, child_errs, child_blocked, name=self.heuristic_name + ) + child_source = "plain" if self.heuristic_name == "plain" else "refined" + child_refined = True + else: + if state.y_prices is None: + raise AssertionError("Expected parent feasible y-prices before projecting to child.") + child_h, child_y, child_source = self.project_child_y( + state, + child_dets, + child_errs, + child_blocked, + child_det_counts, + 
self.error_detectors[ei], + ) + self.projected_nodes_generated += 1 + children_projected += 1 + child_refined = False + + if math.isinf(child_h): + children_infeasible += 1 + continue + if ( + child_refined + and self.heuristic_has_projectable_prices(self.heuristic_name) + and child_y is None + ): + raise AssertionError(f"Expected projectable y-prices from heuristic {self.heuristic_name!r}.") + + child_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=child_h, + h_source=child_source, + refined=child_refined, + y_prices=child_y, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, counter, child_state)) + counter += 1 + nodes_pushed += 1 + children_generated += 1 + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f" expanded children_generated={children_generated} children_projected={children_projected} " + f"beam_pruned={children_beam_pruned} infeasible={children_infeasible} " + f"lp_calls={self.lp_calls} proj_unrefined_so_far={projected_unrefined}" + ) + + return DecodeResult( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + 
elapsed_seconds=time.perf_counter() - start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.weights[errs].sum()) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.error_detectors[int(ei)]: + dets[d] ^= True + return dets + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.error_observables[int(ei)]: + parity[int(obs)] = not parity.get(int(obs), False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + +def sample_detections_and_observables( + circuit: stim.Circuit, + *, + num_shots: int, + seed: int, + num_detectors: int, + num_observables: int, +) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics." 
+ ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--dets", + type=str, + default=None, + help="String of shot dets (e.g., 'shot D0 D1 L2') to parse instead of sampling.", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample from Stim before selecting --shot (default: 100).", + ) + parser.add_argument( + "--shot", + type=int, + default=0, + help="Index of the sampled shot to decode (default: 0).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the residual detector count; use 'inf' for none.", + ) + parser.add_argument( + "--heuristic", + choices=[ + "plain", + "asc_deg", + "desc_plain", + "plain_sweep", + "best_of_two", + "best_of_three", + "exact_lp", + "exact_lp_plus_inactive", + ], + default="best_of_two", + help="Lower-bound heuristic to use during A* search (default: best_of_two).", + ) + parser.add_argument( + "--lazy-reinsert-heuristics", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For non-plain heuristics, seed nodes with plain detcost, refine on pop, and reinsert when the selected " + "heuristic improves the key (default: enabled)." 
+ ), + ) + parser.add_argument( + "--projection-combine-max-plain", + action=argparse.BooleanOptionalAction, + default=True, + help="When projecting parent y-prices to a child, take max(projected, plain detcost) (default: enabled).", + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action=argparse.BooleanOptionalAction, + default=True, + help="Exclude precedence-blocked errors when forming the lower bound (default: enabled).", + ) + parser.add_argument( + "--report-all-root-heuristics", + action="store_true", + help="Print all root-node heuristic values, including exact_lp and exact_lp_plus_inactive, for the selected shot.", + ) + parser.add_argument( + "--skip-decode", + action="store_true", + help="Only report root heuristics; do not run A* search.", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the selected shot's active detector IDs (default: enabled).", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the decoded merged-error indices when decoding succeeds (default: enabled).", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = 
circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) if args.merge_errors else list(iter_dem_errors_from_dem(dem)) + + if args.dets is not None: + shot_dets = np.zeros(dem.num_detectors, dtype=bool) + shot_obs = np.zeros(dem.num_observables, dtype=bool) + for token in args.dets.split(): + if token == "shot": + continue + if token.startswith("D") and token[1:].isdigit(): + d_idx = int(token[1:]) + if d_idx < dem.num_detectors: + shot_dets[d_idx] = True + elif token.startswith("L") and token[1:].isdigit(): + l_idx = int(token[1:]) + if l_idx < dem.num_observables: + shot_obs[l_idx] = True + else: + dets, obs = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + ) + shot_dets = dets[args.shot] + shot_obs = obs[args.shot] + + decoder = GreedySingletonHeuristicDecoder( + errors, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + heuristic=args.heuristic, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + lazy_reinsert_heuristics=args.lazy_reinsert_heuristics, + projection_combine_max_plain=args.projection_combine_max_plain, + verbose_search=args.verbose_search, + ) + + print(f"circuit = {args.circuit}") + print(f"heuristic = {args.heuristic}") + print(f"mode = {decoder.mode_name}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"shot = {args.shot}") + print(f"num_errors = {decoder.num_errors}") + print(f"num_detectors = {decoder.num_detectors}") + print(f"num_observables = {decoder.num_observables}") + print(f"det_beam = {args.det_beam}") + print(f"merge_errors = {args.merge_errors}") + print(f"respect_blocked_errors_in_heuristic = {args.respect_blocked_errors_in_heuristic}") + print(f"lazy_reinsert_heuristics = {args.lazy_reinsert_heuristics}") + print(f"projection_combine_max_plain = {args.projection_combine_max_plain}") + + if 
args.show_shot_detectors: + active_dets = np.flatnonzero(shot_dets) + print("shot_detectors =", " ".join(f"D{d}" for d in active_dets)) + + if args.report_all_root_heuristics: + root_errs = np.zeros(decoder.num_errors, dtype=bool) + root_blocked = np.zeros(decoder.num_errors, dtype=bool) + report = decoder.report_root_heuristics(shot_dets, root_errs, root_blocked) + exact = next((v for k, v in report if k == "exact_lp"), None) + print("root_heuristics:") + for name, value in report: + if exact is not None and not math.isinf(exact) and exact > 0: + ratio = value / exact if not math.isinf(value) else INF + ratio_text = "INF" if math.isinf(ratio) else f"{ratio:.6f}" + else: + ratio_text = "n/a" + value_text = "INF" if math.isinf(value) else f"{value:.12f}" + print(f" {name:>24s} value={value_text} ratio_to_exact={ratio_text}") + + if args.skip_decode: + return 0 + + result = decoder.decode(shot_dets, det_beam=args.det_beam) + print(f"success = {result.success}") + print(f"nodes_pushed = {result.nodes_pushed}") + print(f"nodes_popped = {result.nodes_popped}") + print(f"max_queue_size = {result.max_queue_size}") + print(f"heuristic_calls = {result.heuristic_calls}") + print(f"plain_heuristic_calls = {result.plain_heuristic_calls}") + print(f"projection_heuristic_calls = {result.projection_heuristic_calls}") + print(f"refinement_calls = {result.refinement_calls}") + print(f"lp_calls = {result.lp_calls}") + print(f"reinserts = {result.reinserts}") + print(f"projected_nodes_generated = {result.projected_nodes_generated}") + print(f"projected_nodes_refined = {result.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish = {result.projected_nodes_unrefined_at_finish}") + print(f"total_refinement_gain = {result.total_refinement_gain:.6f}") + print(f"max_refinement_gain = {result.max_refinement_gain:.6f}") + print(f"elapsed_seconds = {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + if args.show_error_indices: + 
print("decoded_error_indices =", " ".join(map(str, np.flatnonzero(result.errs).tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + decoded_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + sampled_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors = {int(result.errs.sum())}") + print(f"decoded_cost = {decoded_cost:.12f}") + print("predicted_observables =", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables =", " ".join(f"L{o}" for o in sampled_obs.tolist())) + print(f"observables_match = {bool(np.array_equal(predicted_obs, sampled_obs))}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_singleton_lp_probe.py b/src/py/astar/astar_singleton_lp_probe.py new file mode 100644 index 0000000..8b321e0 --- /dev/null +++ b/src/py/astar/astar_singleton_lp_probe.py @@ -0,0 +1,1276 @@ +#!/usr/bin/env python3 +"""Instrumented A* prototype for studying the optimal singleton LP heuristic. + +This script is intentionally data-heavy and not heavily optimized. It decodes a +set of Stim circuits, samples several shots from each, and writes detailed logs +about every heuristic evaluation during search. + +Outputs (written under --output-dir): + manifest.json + shot_summaries.jsonl + node_summaries.jsonl.gz + component_summaries.jsonl.gz + sampled_instances.jsonl.gz + +The node/component logs are designed to answer questions such as: + * How often is the singleton LP graphlike (all distinct supports have size <= 2)? + * How many connected components does the residual support hypergraph have? + * How many raw allowed errors collapse to the same distinct active support? + * How sparse are primal/dual LP solutions? 
+ * Are graphlike components common enough to justify a specialized solver? + +The search tree uses the same precedence-style pruning idea as the prototype and +Tesseract paper: at each node, only errors incident to the minimum active +residual detector are expanded, with earlier siblings blocked to keep a unique +path ordering. The A* heuristic can be plain detcost or the optimal singleton +LP; both values are logged for every created node. +""" + +from __future__ import annotations + +import argparse +import gzip +import heapq +import json +import math +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy import sparse +from scipy.optimize import linprog + +INF = math.inf +JSON_SEPARATORS = (",", ":") +LP_TOL = 1e-9 +RATIONAL_TOL = 1e-7 + + +@dataclass(frozen=True) +class MergedError: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] + + +@dataclass +class DecoderData: + num_detectors: int + num_observables: int + errors: List[MergedError] + detector_to_errors: List[List[int]] + error_costs: np.ndarray + error_detectors: List[Tuple[int, ...]] + error_observables: List[Tuple[int, ...]] + + +@dataclass +class SearchSettings: + det_beam: float + search_heuristic: str + respect_blocked_errors_in_heuristic: bool + max_nodes_popped: Optional[int] + max_nodes_pushed: Optional[int] + sample_raw_nodes_per_shot: int + verbose_search: bool + + +@dataclass +class SearchState: + node_id: int + parent_node_id: Optional[int] + incoming_error_index: Optional[int] + depth: int + activated_errors: Tuple[int, ...] 
+ activated_error_mask: np.ndarray + blocked_errors: np.ndarray + active_detectors: np.ndarray + active_detector_counts: np.ndarray + path_cost: float + search_h: float + plain_h: float + opt_h: float + + +class JsonlWriter: + def __init__(self, path: Path, *, gz: bool = False): + self.path = path + path.parent.mkdir(parents=True, exist_ok=True) + if gz: + self.file = gzip.open(path, "wt", encoding="utf-8") + else: + self.file = open(path, "wt", encoding="utf-8") + + def write(self, record: Dict[str, Any]) -> None: + self.file.write(json.dumps(record, separators=JSON_SEPARATORS, sort_keys=True)) + self.file.write("\n") + + def flush(self) -> None: + self.file.flush() + + def close(self) -> None: + self.file.close() + + +class UnionFind: + def __init__(self, size: int) -> None: + self.parent = list(range(size)) + self.rank = [0] * size + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + +class ShotAggregator: + def __init__(self) -> None: + self.nodes_created = 0 + self.nodes_pushed = 0 + self.nodes_infeasible = 0 + self.nodes_graphlike = 0 + self.nodes_with_lp = 0 + self.total_plain_h = 0.0 + self.total_opt_h = 0.0 + self.total_h_gain = 0.0 + self.total_lp_time_sec = 0.0 + self.total_lp_vars = 0 + self.total_lp_constraints = 0 + self.total_raw_allowed_errors = 0 + self.total_distinct_supports = 0 + self.total_components = 0 + self.total_graphlike_components = 0 + self.max_active_detectors = 0 + self.max_distinct_supports = 0 + self.max_component_variables = 0 + self.max_component_constraints = 0 + + def absorb_node(self, node_record: Dict[str, Any]) -> None: + self.nodes_created += 1 + 
self.nodes_pushed += int(bool(node_record["pushed"]))
        self.nodes_infeasible += int(bool(node_record["opt_infeasible"]))
        self.nodes_graphlike += int(bool(node_record["graphlike_all_components"]))
        self.nodes_with_lp += int(node_record["lp_calls"] > 0)
        self.total_plain_h += float(node_record["plain_h"])
        # opt_h is +inf on infeasible nodes; only fold it into the running
        # totals when the singleton LP produced a finite optimum.
        if not node_record["opt_infeasible"]:
            self.total_opt_h += float(node_record["opt_h"])
            self.total_h_gain += float(node_record["opt_h"] - node_record["plain_h"])
        self.total_lp_time_sec += float(node_record["lp_time_sec"])
        self.total_lp_vars += int(node_record["total_lp_vars"])
        self.total_lp_constraints += int(node_record["total_lp_constraints"])
        self.total_raw_allowed_errors += int(node_record["raw_allowed_errors"])
        self.total_distinct_supports += int(node_record["distinct_supports"])
        self.total_components += int(node_record["num_components"])
        self.total_graphlike_components += int(node_record["num_graphlike_components"])
        # Running maxima over all nodes seen for this shot.
        self.max_active_detectors = max(self.max_active_detectors, int(node_record["num_active_detectors"]))
        self.max_distinct_supports = max(self.max_distinct_supports, int(node_record["distinct_supports"]))
        self.max_component_variables = max(self.max_component_variables, int(node_record["max_component_variables"]))
        self.max_component_constraints = max(self.max_component_constraints, int(node_record["max_component_constraints"]))

    def finish(self, *, nodes_popped: int, status: str, elapsed_seconds: float) -> Dict[str, Any]:
        """Summarize the per-shot tallies as a JSON-ready record.

        Denominators are clamped to at least 1 (``n`` = created nodes,
        ``c`` = components) so empty or immediately-terminated searches do
        not divide by zero.
        """
        n = max(self.nodes_created, 1)
        c = max(self.total_components, 1)
        return {
            "status": status,
            "nodes_created": self.nodes_created,
            "nodes_pushed": self.nodes_pushed,
            "nodes_popped": nodes_popped,
            "nodes_infeasible": self.nodes_infeasible,
            "graphlike_node_fraction": self.nodes_graphlike / n,
            "mean_plain_h": self.total_plain_h / n,
            # Means over feasible nodes only, matching the accumulation above.
            "mean_opt_h_over_feasible": (self.total_opt_h / max(self.nodes_created - self.nodes_infeasible, 1)),
            "mean_opt_minus_plain_over_feasible":
(self.total_h_gain / max(self.nodes_created - self.nodes_infeasible, 1)), + "total_lp_time_sec": self.total_lp_time_sec, + "mean_lp_time_per_created_node_sec": self.total_lp_time_sec / n, + "mean_lp_vars_per_created_node": self.total_lp_vars / n, + "mean_lp_constraints_per_created_node": self.total_lp_constraints / n, + "mean_raw_allowed_errors": self.total_raw_allowed_errors / n, + "mean_distinct_supports": self.total_distinct_supports / n, + "mean_components": self.total_components / n, + "graphlike_component_fraction": self.total_graphlike_components / c, + "max_active_detectors": self.max_active_detectors, + "max_distinct_supports": self.max_distinct_supports, + "max_component_variables": self.max_component_variables, + "max_component_constraints": self.max_component_constraints, + "elapsed_seconds": elapsed_seconds, + } + + +class NodeSampler: + def __init__(self, sample_raw_nodes_per_shot: int): + self.sample_raw_nodes_per_shot = sample_raw_nodes_per_shot + self.seen = 0 + + def should_sample(self, node_id: int) -> bool: + del node_id + if self.seen < self.sample_raw_nodes_per_shot: + self.seen += 1 + return True + return False + + +class ProbeLogger: + def __init__(self, output_dir: Path): + self.output_dir = output_dir + self.shot_writer = JsonlWriter(output_dir / "shot_summaries.jsonl", gz=False) + self.node_writer = JsonlWriter(output_dir / "node_summaries.jsonl.gz", gz=True) + self.component_writer = JsonlWriter(output_dir / "component_summaries.jsonl.gz", gz=True) + self.sample_writer = JsonlWriter(output_dir / "sampled_instances.jsonl.gz", gz=True) + + def close(self) -> None: + self.shot_writer.close() + self.node_writer.close() + self.component_writer.close() + self.sample_writer.close() + + def flush(self) -> None: + self.shot_writer.flush() + self.node_writer.flush() + self.component_writer.flush() + self.sample_writer.flush() + + +def parse_optional_int(text: str) -> Optional[int]: + lowered = text.strip().lower() + if lowered in {"none", "inf", 
"infinity", "+inf", "+infinity"}: + return None + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("must be non-negative or one of: none, inf") + return value + + +def parse_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "+inf", "+infinity"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("beam must be non-negative or 'inf'") + return float(value) + + +def xor_probability(p0: float, p1: float) -> float: + return p0 * (1 - p1) + (1 - p0) * p1 + + +def iter_dem_errors(dem: stim.DetectorErrorModel) -> Iterable[MergedError]: + for instruction in dem.flattened(): + if instruction.type != "error": + continue + probability = float(instruction.args_copy()[0]) + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + "This prototype assumes detector-error-model probabilities are in (0, 0.5)." + ) + detectors: set[int] = set() + observables: set[int] = set() + for target in instruction.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + assert target.is_relative_detector_id() + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + yield MergedError( + probability=probability, + likelihood_cost=float(-math.log(probability / (1 - probability))), + detectors=tuple(sorted(detectors)), + observables=tuple(sorted(observables)), + ) + + +def merged_errors(dem: stim.DetectorErrorModel) -> List[MergedError]: + probabilities: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + for error in iter_dem_errors(dem): + key = (error.detectors, error.observables) + prev = probabilities.get(key) + probabilities[key] = error.probability if prev is None else xor_probability(prev, error.probability) + + out: List[MergedError] = [] + for (detectors, observables), 
probability in probabilities.items(): + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError("Merged error has probability >= 0.5.") + out.append( + MergedError( + probability=probability, + likelihood_cost=float(-math.log(probability / (1 - probability))), + detectors=detectors, + observables=observables, + ) + ) + return out + + +def build_decoder_data(dem: stim.DetectorErrorModel, *, merge_errors_in_dem: bool = True) -> DecoderData: + errors = merged_errors(dem) if merge_errors_in_dem else list(iter_dem_errors(dem)) + detector_to_errors: List[List[int]] = [[] for _ in range(dem.num_detectors)] + for ei, error in enumerate(errors): + for d in error.detectors: + detector_to_errors[d].append(ei) + return DecoderData( + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + errors=errors, + detector_to_errors=detector_to_errors, + error_costs=np.asarray([e.likelihood_cost for e in errors], dtype=np.float64), + error_detectors=[e.detectors for e in errors], + error_observables=[e.observables for e in errors], + ) + + +def unpack_bit_packed_rows(bits: np.ndarray, count: int) -> np.ndarray: + return np.unpackbits(bits, bitorder="little", axis=1, count=count).astype(bool, copy=False) + + +def initial_detector_counts(data: DecoderData, active_detectors: np.ndarray) -> np.ndarray: + counts = np.zeros(len(data.errors), dtype=np.int32) + for d in np.flatnonzero(active_detectors): + for ei in data.detector_to_errors[int(d)]: + counts[ei] += 1 + return counts + + +def apply_error( + data: DecoderData, + active_detectors: np.ndarray, + active_detector_counts: np.ndarray, + error_index: int, +) -> Tuple[np.ndarray, np.ndarray]: + next_detectors = active_detectors.copy() + next_counts = active_detector_counts.copy() + for d in data.error_detectors[error_index]: + if next_detectors[d]: + next_detectors[d] = False + delta = -1 + else: + next_detectors[d] = True + delta = 1 + for other_error_index in data.detector_to_errors[d]: + 
next_counts[other_error_index] += delta + return next_detectors, next_counts + + +def plain_detcost_for_detector( + data: DecoderData, + detector: int, + *, + activated_error_mask: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, + respect_blocked_errors_in_heuristic: bool, +) -> float: + best = INF + for ei in data.detector_to_errors[detector]: + if respect_blocked_errors_in_heuristic: + if blocked_errors[ei]: + continue + else: + if activated_error_mask[ei]: + continue + count = int(active_detector_counts[ei]) + assert count > 0 + candidate = float(data.error_costs[ei]) / count + if candidate < best: + best = candidate + return best + + +def plain_detcost_heuristic( + data: DecoderData, + active_detectors: np.ndarray, + *, + activated_error_mask: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, + respect_blocked_errors_in_heuristic: bool, +) -> float: + total = 0.0 + for d in np.flatnonzero(active_detectors): + det_cost = plain_detcost_for_detector( + data=data, + detector=int(d), + activated_error_mask=activated_error_mask, + blocked_errors=blocked_errors, + active_detector_counts=active_detector_counts, + respect_blocked_errors_in_heuristic=respect_blocked_errors_in_heuristic, + ) + if det_cost == INF: + return INF + total += det_cost + return total + + +def grid_fraction(values: np.ndarray, denominator: int, tol: float = RATIONAL_TOL) -> float: + if values.size == 0: + return 0.0 + scaled = denominator * values + return float(np.mean(np.abs(scaled - np.round(scaled)) <= tol)) + + +@dataclass +class LPProbeResult: + opt_h: float + node_record: Dict[str, Any] + component_records: List[Dict[str, Any]] + sample_record: Optional[Dict[str, Any]] + + +def probe_opt_singleton_lp( + *, + run_id: str, + circuit_name: str, + shot_index: int, + state: SearchState, + data: DecoderData, + settings: SearchSettings, + plain_h: float, + sample_raw_instance: bool, +) -> LPProbeResult: + active_detector_ids = 
np.flatnonzero(state.active_detectors) + num_active_detectors = int(active_detector_ids.size) + global_to_local = np.full(data.num_detectors, -1, dtype=np.int32) + global_to_local[active_detector_ids] = np.arange(num_active_detectors, dtype=np.int32) + + support_to_cost: Dict[Tuple[int, ...], float] = {} + support_to_multiplicity: Dict[Tuple[int, ...], int] = {} + covered = np.zeros(num_active_detectors, dtype=bool) + + raw_allowed_errors = 0 + raw_support_size_hist = {"1": 0, "2": 0, "3": 0, "4+": 0} + + for ei, error_detectors in enumerate(data.error_detectors): + if settings.respect_blocked_errors_in_heuristic: + if state.blocked_errors[ei]: + continue + else: + if state.activated_error_mask[ei]: + continue + + count = int(state.active_detector_counts[ei]) + if count == 0: + continue + support = tuple(int(global_to_local[d]) for d in error_detectors if state.active_detectors[d]) + assert support + raw_allowed_errors += 1 + size = len(support) + if size == 1: + raw_support_size_hist["1"] += 1 + elif size == 2: + raw_support_size_hist["2"] += 1 + elif size == 3: + raw_support_size_hist["3"] += 1 + else: + raw_support_size_hist["4+"] += 1 + covered[list(support)] = True + support_to_multiplicity[support] = support_to_multiplicity.get(support, 0) + 1 + cost = float(data.error_costs[ei]) + prev = support_to_cost.get(support) + if prev is None or cost < prev: + support_to_cost[support] = cost + + distinct_support_size_hist = {"1": 0, "2": 0, "3": 0, "4+": 0} + for support in support_to_cost: + size = len(support) + if size == 1: + distinct_support_size_hist["1"] += 1 + elif size == 2: + distinct_support_size_hist["2"] += 1 + elif size == 3: + distinct_support_size_hist["3"] += 1 + else: + distinct_support_size_hist["4+"] += 1 + + uncovered_count = int(np.count_nonzero(~covered)) + base_node_record: Dict[str, Any] = { + "run_id": run_id, + "circuit": circuit_name, + "shot": shot_index, + "node_id": state.node_id, + "parent_node_id": state.parent_node_id, + 
"incoming_error_index": state.incoming_error_index, + "depth": state.depth, + "num_active_detectors": num_active_detectors, + "path_cost": state.path_cost, + "plain_h": plain_h, + "raw_allowed_errors": raw_allowed_errors, + "raw_support_hist": raw_support_size_hist, + "distinct_supports": len(support_to_cost), + "distinct_support_hist": distinct_support_size_hist, + "support_multiplicity_mean": (float(np.mean(list(support_to_multiplicity.values()))) if support_to_multiplicity else 0.0), + "support_multiplicity_max": (max(support_to_multiplicity.values()) if support_to_multiplicity else 0), + "uncovered_active_detectors": uncovered_count, + } + + if uncovered_count > 0: + base_node_record.update( + { + "opt_h": INF, + "opt_infeasible": True, + "lp_calls": 0, + "lp_time_sec": 0.0, + "total_lp_vars": 0, + "total_lp_constraints": 0, + "num_components": 0, + "num_graphlike_components": 0, + "graphlike_all_components": False, + "max_support_size": 0, + "max_component_variables": 0, + "max_component_constraints": 0, + "positive_y_count": 0, + "tight_constraint_count": 0, + "positive_dual_count": 0, + } + ) + sample_record = None + if sample_raw_instance: + sample_record = { + "run_id": run_id, + "circuit": circuit_name, + "shot": shot_index, + "node_id": state.node_id, + "parent_node_id": state.parent_node_id, + "depth": state.depth, + "opt_infeasible": True, + "active_detector_ids": active_detector_ids.tolist(), + "supports": [ + { + "local_support": list(support), + "global_support": [int(active_detector_ids[i]) for i in support], + "cost": support_to_cost[support], + "multiplicity": support_to_multiplicity[support], + } + for support in sorted(support_to_cost) + ], + } + return LPProbeResult( + opt_h=INF, + node_record=base_node_record, + component_records=[], + sample_record=sample_record, + ) + + union_find = UnionFind(num_active_detectors) + for support in support_to_cost: + first = support[0] + for detector in support[1:]: + union_find.union(first, detector) + + 
detectors_by_root: Dict[int, List[int]] = {} + for detector in range(num_active_detectors): + root = union_find.find(detector) + detectors_by_root.setdefault(root, []).append(detector) + + supports_by_root: Dict[int, List[Tuple[Tuple[int, ...], float, int]]] = {} + for support, cost in support_to_cost.items(): + root = union_find.find(support[0]) + supports_by_root.setdefault(root, []).append((support, cost, support_to_multiplicity[support])) + + component_records: List[Dict[str, Any]] = [] + sample_components: List[Dict[str, Any]] = [] + total_opt_h = 0.0 + total_lp_time = 0.0 + total_lp_vars = 0 + total_lp_constraints = 0 + total_positive_y = 0 + total_tight_constraints = 0 + total_positive_dual = 0 + num_graphlike_components = 0 + max_component_variables = 0 + max_component_constraints = 0 + max_support_size = max((len(support) for support in support_to_cost), default=0) + + for component_index, (root, component_detectors) in enumerate(sorted(detectors_by_root.items())): + local_reindex = {detector: i for i, detector in enumerate(component_detectors)} + component_supports = supports_by_root[root] + num_vars = len(component_detectors) + num_constraints = len(component_supports) + max_component_variables = max(max_component_variables, num_vars) + max_component_constraints = max(max_component_constraints, num_constraints) + total_lp_vars += num_vars + total_lp_constraints += num_constraints + + row_indices: List[int] = [] + col_indices: List[int] = [] + values: List[float] = [] + rhs = np.empty(num_constraints, dtype=np.float64) + component_global_supports: List[List[int]] = [] + support_sizes = np.empty(num_constraints, dtype=np.int32) + multiplicities = np.empty(num_constraints, dtype=np.int32) + + graphlike = True + support_size_hist = {"1": 0, "2": 0, "3": 0, "4+": 0} + for row, (support, cost, multiplicity) in enumerate(component_supports): + rhs[row] = cost + multiplicities[row] = multiplicity + reindexed_support = [local_reindex[d] for d in support] + 
support_sizes[row] = len(reindexed_support) + if support_sizes[row] == 1: + support_size_hist["1"] += 1 + elif support_sizes[row] == 2: + support_size_hist["2"] += 1 + elif support_sizes[row] == 3: + support_size_hist["3"] += 1 + graphlike = False + else: + support_size_hist["4+"] += 1 + graphlike = False + component_global_supports.append([int(active_detector_ids[d]) for d in support]) + for col in reindexed_support: + row_indices.append(row) + col_indices.append(col) + values.append(1.0) + + a_ub = sparse.csr_matrix( + (values, (row_indices, col_indices)), + shape=(num_constraints, num_vars), + dtype=np.float64, + ) + + t0 = time.perf_counter() + result = linprog( + c=-np.ones(num_vars, dtype=np.float64), + A_ub=a_ub, + b_ub=rhs, + bounds=[(0.0, None)] * num_vars, + method="highs", + ) + lp_time_sec = time.perf_counter() - t0 + total_lp_time += lp_time_sec + if not result.success: + raise RuntimeError( + f"LP solve failed for circuit={circuit_name} shot={shot_index} node={state.node_id} " + f"component={component_index}: {result.message}" + ) + + y = np.asarray(result.x, dtype=np.float64) + total_opt_h += float(-result.fun) + positive_y_mask = y > LP_TOL + positive_y_count = int(np.count_nonzero(positive_y_mask)) + total_positive_y += positive_y_count + + if hasattr(result, "ineqlin") and hasattr(result.ineqlin, "residual"): + residual = np.asarray(result.ineqlin.residual, dtype=np.float64) + else: + residual = rhs - a_ub.dot(y) + tight_mask = residual <= LP_TOL + tight_count = int(np.count_nonzero(tight_mask)) + total_tight_constraints += tight_count + + if hasattr(result, "ineqlin") and hasattr(result.ineqlin, "marginals"): + dual = -np.asarray(result.ineqlin.marginals, dtype=np.float64) + else: + dual = np.full(num_constraints, np.nan) + if np.isnan(dual).any(): + positive_dual_mask = np.zeros(num_constraints, dtype=bool) + positive_dual = np.zeros(0, dtype=np.float64) + else: + positive_dual_mask = dual > LP_TOL + positive_dual = dual[positive_dual_mask] + 
positive_dual_count = int(np.count_nonzero(positive_dual_mask)) + total_positive_dual += positive_dual_count + + if graphlike: + num_graphlike_components += 1 + + positive_dual_size_hist = {"1": 0, "2": 0, "3": 0, "4+": 0} + for size in support_sizes[positive_dual_mask]: + if size == 1: + positive_dual_size_hist["1"] += 1 + elif size == 2: + positive_dual_size_hist["2"] += 1 + elif size == 3: + positive_dual_size_hist["3"] += 1 + else: + positive_dual_size_hist["4+"] += 1 + + component_record = { + "run_id": run_id, + "circuit": circuit_name, + "shot": shot_index, + "node_id": state.node_id, + "component_index": component_index, + "num_variables": num_vars, + "num_constraints": num_constraints, + "objective": float(-result.fun), + "lp_time_sec": lp_time_sec, + "graphlike": graphlike, + "max_support_size": int(np.max(support_sizes) if support_sizes.size else 0), + "support_hist": support_size_hist, + "positive_y_count": positive_y_count, + "tight_constraint_count": tight_count, + "positive_dual_count": positive_dual_count, + "dual_integral_fraction": grid_fraction(positive_dual, 1), + "dual_half_integral_fraction": grid_fraction(positive_dual, 2), + "dual_third_integral_fraction": grid_fraction(positive_dual, 3), + "dual_quarter_integral_fraction": grid_fraction(positive_dual, 4), + "positive_dual_support_hist": positive_dual_size_hist, + "support_multiplicity_mean": float(np.mean(multiplicities)) if multiplicities.size else 0.0, + "support_multiplicity_max": int(np.max(multiplicities) if multiplicities.size else 0), + } + component_records.append(component_record) + + if sample_raw_instance: + sample_components.append( + { + "component_index": component_index, + "global_detector_ids": [int(active_detector_ids[d]) for d in component_detectors], + "supports": [ + { + "global_support": component_global_supports[row], + "cost": float(rhs[row]), + "multiplicity": int(multiplicities[row]), + "dual": float(dual[row]) if not np.isnan(dual[row]) else None, + "slack": 
float(residual[row]), + } + for row in range(num_constraints) + ], + "y": [float(v) for v in y], + } + ) + + base_node_record.update( + { + "opt_h": total_opt_h, + "opt_infeasible": False, + "lp_calls": len(component_records), + "lp_time_sec": total_lp_time, + "total_lp_vars": total_lp_vars, + "total_lp_constraints": total_lp_constraints, + "num_components": len(component_records), + "num_graphlike_components": num_graphlike_components, + "graphlike_all_components": num_graphlike_components == len(component_records), + "max_support_size": max_support_size, + "max_component_variables": max_component_variables, + "max_component_constraints": max_component_constraints, + "positive_y_count": total_positive_y, + "tight_constraint_count": total_tight_constraints, + "positive_dual_count": total_positive_dual, + } + ) + + sample_record = None + if sample_raw_instance: + sample_record = { + "run_id": run_id, + "circuit": circuit_name, + "shot": shot_index, + "node_id": state.node_id, + "parent_node_id": state.parent_node_id, + "incoming_error_index": state.incoming_error_index, + "depth": state.depth, + "path_cost": state.path_cost, + "plain_h": plain_h, + "opt_h": total_opt_h, + "active_detector_ids": active_detector_ids.tolist(), + "components": sample_components, + } + + return LPProbeResult( + opt_h=total_opt_h, + node_record=base_node_record, + component_records=component_records, + sample_record=sample_record, + ) + + +def compute_node_metrics( + *, + run_id: str, + circuit_name: str, + shot_index: int, + state: SearchState, + data: DecoderData, + settings: SearchSettings, + sample_raw_instance: bool, +) -> LPProbeResult: + plain_h = plain_detcost_heuristic( + data=data, + active_detectors=state.active_detectors, + activated_error_mask=state.activated_error_mask, + blocked_errors=state.blocked_errors, + active_detector_counts=state.active_detector_counts, + respect_blocked_errors_in_heuristic=settings.respect_blocked_errors_in_heuristic, + ) + lp_probe = 
probe_opt_singleton_lp( + run_id=run_id, + circuit_name=circuit_name, + shot_index=shot_index, + state=state, + data=data, + settings=settings, + plain_h=plain_h, + sample_raw_instance=sample_raw_instance, + ) + return lp_probe + + +def observables_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + observables = np.zeros(data.num_observables, dtype=bool) + for error_index in activated_errors: + for observable in data.error_observables[error_index]: + observables[observable] ^= True + return observables + + +def detectors_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + detectors = np.zeros(data.num_detectors, dtype=bool) + for error_index in activated_errors: + for detector in data.error_detectors[error_index]: + detectors[detector] ^= True + return detectors + + +def heuristic_for_search(settings: SearchSettings, plain_h: float, opt_h: float) -> float: + if settings.search_heuristic == "plain": + return plain_h + if settings.search_heuristic == "opt": + return opt_h + raise ValueError(f"Unknown search heuristic: {settings.search_heuristic}") + + +def decode_and_probe_shot( + *, + run_id: str, + circuit_name: str, + shot_index: int, + shot_detections: np.ndarray, + shot_observables: np.ndarray, + data: DecoderData, + settings: SearchSettings, + logger: ProbeLogger, +) -> Dict[str, Any]: + shot_start = time.perf_counter() + sampler = NodeSampler(settings.sample_raw_nodes_per_shot) + aggregator = ShotAggregator() + + initial_active_detectors = np.asarray(shot_detections, dtype=bool).copy() + initial_counts = initial_detector_counts(data, initial_active_detectors) + initial_activated_mask = np.zeros(len(data.errors), dtype=bool) + initial_blocked = np.zeros(len(data.errors), dtype=bool) + + root_state = SearchState( + node_id=0, + parent_node_id=None, + incoming_error_index=None, + depth=0, + activated_errors=(), + activated_error_mask=initial_activated_mask, + blocked_errors=initial_blocked, + 
active_detectors=initial_active_detectors, + active_detector_counts=initial_counts, + path_cost=0.0, + search_h=0.0, + plain_h=0.0, + opt_h=0.0, + ) + + root_probe = compute_node_metrics( + run_id=run_id, + circuit_name=circuit_name, + shot_index=shot_index, + state=root_state, + data=data, + settings=settings, + sample_raw_instance=sampler.should_sample(root_state.node_id), + ) + root_state.plain_h = float(root_probe.node_record["plain_h"]) + root_state.opt_h = float(root_probe.node_record["opt_h"]) + root_state.search_h = heuristic_for_search(settings, root_state.plain_h, root_state.opt_h) + if root_state.search_h == INF: + raise RuntimeError( + f"Root node is infeasible for circuit={circuit_name} shot={shot_index}." + ) + + root_record = { + **root_probe.node_record, + "search_h": root_state.search_h, + "f_cost": root_state.path_cost + root_state.search_h, + "pushed": True, + } + logger.node_writer.write(root_record) + for component_record in root_probe.component_records: + logger.component_writer.write(component_record) + if root_probe.sample_record is not None: + logger.sample_writer.write(root_probe.sample_record) + aggregator.absorb_node(root_record) + + queue: List[Tuple[float, int, int, SearchState]] = [] + heapq_push_counter = 0 + npush = 1 + popped = 0 + max_queue_size = 1 + min_num_dets = int(initial_active_detectors.sum()) + max_num_dets = INF if settings.det_beam == INF else min_num_dets + settings.det_beam + heapq.heappush(queue, (root_state.path_cost + root_state.search_h, min_num_dets, heapq_push_counter, root_state)) + heapq_push_counter += 1 + next_node_id = 1 + + solution_state: Optional[SearchState] = None + status = "unknown" + + while queue: + max_queue_size = max(max_queue_size, len(queue)) + f_cost, num_dets, _, state = heapq.heappop(queue) + popped += 1 + + if settings.max_nodes_popped is not None and popped > settings.max_nodes_popped: + status = "max_nodes_popped" + break + + if num_dets > max_num_dets: + continue + + if 
settings.verbose_search: + print( + f"[{circuit_name} shot={shot_index}] nodes_popped={popped} pq={len(queue)} " + f"active_dets={num_dets} max_active_dets={max_num_dets} depth={state.depth} " + f"g={state.path_cost:.12g} h={state.search_h:.12g} f={f_cost:.12g}", + flush=True, + ) + + if num_dets == 0: + solution_state = state + status = "success" + break + + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = INF if settings.det_beam == INF else min_num_dets + settings.det_beam + + min_detector = int(np.flatnonzero(state.active_detectors)[0]) + blocked_prefix = state.blocked_errors.copy() + + for error_index in data.detector_to_errors[min_detector]: + blocked_prefix[error_index] = True + if state.blocked_errors[error_index]: + continue + + child_active_detectors, child_counts = apply_error( + data=data, + active_detectors=state.active_detectors, + active_detector_counts=state.active_detector_counts, + error_index=error_index, + ) + child_num_dets = int(child_active_detectors.sum()) + if child_num_dets > max_num_dets: + continue + + child_activated_mask = state.activated_error_mask.copy() + child_activated_mask[error_index] = True + child_blocked = blocked_prefix.copy() + child_path_cost = state.path_cost + float(data.error_costs[error_index]) + + child_state = SearchState( + node_id=next_node_id, + parent_node_id=state.node_id, + incoming_error_index=error_index, + depth=state.depth + 1, + activated_errors=state.activated_errors + (error_index,), + activated_error_mask=child_activated_mask, + blocked_errors=child_blocked, + active_detectors=child_active_detectors, + active_detector_counts=child_counts, + path_cost=child_path_cost, + search_h=0.0, + plain_h=0.0, + opt_h=0.0, + ) + next_node_id += 1 + + child_probe = compute_node_metrics( + run_id=run_id, + circuit_name=circuit_name, + shot_index=shot_index, + state=child_state, + data=data, + settings=settings, + sample_raw_instance=sampler.should_sample(child_state.node_id), + ) + 
child_state.plain_h = float(child_probe.node_record["plain_h"]) + child_state.opt_h = float(child_probe.node_record["opt_h"]) + child_state.search_h = heuristic_for_search(settings, child_state.plain_h, child_state.opt_h) + + pushed = child_state.search_h != INF + child_record = { + **child_probe.node_record, + "search_h": child_state.search_h, + "f_cost": child_state.path_cost + child_state.search_h, + "pushed": pushed, + } + logger.node_writer.write(child_record) + for component_record in child_probe.component_records: + logger.component_writer.write(component_record) + if child_probe.sample_record is not None: + logger.sample_writer.write(child_probe.sample_record) + aggregator.absorb_node(child_record) + + if not pushed: + continue + + heapq.heappush( + queue, + ( + child_state.path_cost + child_state.search_h, + child_num_dets, + heapq_push_counter, + child_state, + ), + ) + heapq_push_counter += 1 + npush += 1 + if settings.max_nodes_pushed is not None and npush > settings.max_nodes_pushed: + status = "max_nodes_pushed" + queue.clear() + break + + if status == "max_nodes_pushed": + break + + if status == "unknown": + status = "empty_queue" + + elapsed_seconds = time.perf_counter() - shot_start + predicted_observables: Optional[np.ndarray] = None + solution_cost: Optional[float] = None + observables_match: Optional[bool] = None + solution_size: Optional[int] = None + + if solution_state is not None: + reproduced_detectors = detectors_from_solution(data, solution_state.activated_errors) + if not np.array_equal(reproduced_detectors, shot_detections): + raise AssertionError( + f"Decoded error set does not reproduce the shot syndrome for circuit={circuit_name} shot={shot_index}." 
+ ) + predicted_observables = observables_from_solution(data, solution_state.activated_errors) + observables_match = bool(np.array_equal(predicted_observables, shot_observables)) + solution_cost = float(solution_state.path_cost) + solution_size = len(solution_state.activated_errors) + + summary = { + "run_id": run_id, + "circuit": circuit_name, + "shot": shot_index, + **aggregator.finish(nodes_popped=popped, status=status, elapsed_seconds=elapsed_seconds), + "max_queue_size": max_queue_size, + "det_beam": settings.det_beam, + "search_heuristic": settings.search_heuristic, + "solution_cost": solution_cost, + "solution_size": solution_size, + "observables_match": observables_match, + "predicted_observables": (np.flatnonzero(predicted_observables).tolist() if predicted_observables is not None else None), + "sample_observables": np.flatnonzero(shot_observables).tolist(), + } + logger.shot_writer.write(summary) + logger.flush() + return summary + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Run the prototype decoder on several circuits and log detailed LP-structure data " + "for the optimal singleton heuristic." + ) + ) + parser.add_argument( + "circuits", + nargs="+", + type=Path, + help="Stim circuit files to analyze.", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Directory where logs will be written.", + ) + parser.add_argument( + "--shots-per-circuit", + type=int, + default=10, + help="Number of sampled shots to decode per circuit (default: 10).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Seed passed to stim.compile_detector_sampler(...).sample(...).", + ) + parser.add_argument( + "--det-beam", + type=parse_beam, + default=INF, + help="Beam cutoff on residual detector count. 
Use an integer or 'inf'.", + ) + parser.add_argument( + "--max-nodes-popped", + type=parse_optional_int, + default=5000, + help="Stop after this many popped nodes per shot (default: 5000; use 'none' for no limit).", + ) + parser.add_argument( + "--max-nodes-pushed", + type=parse_optional_int, + default=50000, + help="Stop after this many pushed nodes per shot (default: 50000; use 'none' for no limit).", + ) + parser.add_argument( + "--search-heuristic", + choices=["plain", "opt"], + default="opt", + help="Heuristic used for queue ordering. Both plain and optimal values are always logged.", + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action="store_true", + help=( + "When set, both heuristics exclude precedence-blocked errors as well as already-activated errors. " + "By default, heuristics only exclude already-activated errors, matching the original prototype." + ), + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--sample-raw-nodes-per-shot", + type=int, + default=25, + help="How many raw LP instances to dump per shot (default: 25).", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print one line per popped node.", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Suppress per-shot progress printing.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.shots_per_circuit <= 0: + parser.error("--shots-per-circuit must be positive.") + if args.sample_raw_nodes_per_shot < 0: + parser.error("--sample-raw-nodes-per-shot must be non-negative.") + + output_dir: Path = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + run_id = f"singleton_lp_probe_{int(time.time())}" + + manifest = { + "run_id": run_id, + "argv": 
list(argv) if argv is not None else sys.argv[1:], + "circuits": [str(p) for p in args.circuits], + "shots_per_circuit": args.shots_per_circuit, + "seed": args.seed, + "det_beam": args.det_beam, + "max_nodes_popped": args.max_nodes_popped, + "max_nodes_pushed": args.max_nodes_pushed, + "search_heuristic": args.search_heuristic, + "respect_blocked_errors_in_heuristic": args.respect_blocked_errors_in_heuristic, + "merge_errors": args.merge_errors, + "sample_raw_nodes_per_shot": args.sample_raw_nodes_per_shot, + "lp_tol": LP_TOL, + "rational_tol": RATIONAL_TOL, + } + (output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + logger = ProbeLogger(output_dir) + settings = SearchSettings( + det_beam=args.det_beam, + search_heuristic=args.search_heuristic, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + max_nodes_popped=args.max_nodes_popped, + max_nodes_pushed=args.max_nodes_pushed, + sample_raw_nodes_per_shot=args.sample_raw_nodes_per_shot, + verbose_search=args.verbose_search, + ) + + try: + for circuit_path in args.circuits: + circuit = stim.Circuit.from_file(str(circuit_path)) + dem = circuit.detector_error_model(decompose_errors=False) + data = build_decoder_data(dem, merge_errors_in_dem=args.merge_errors) + dets_packed, obs_packed = circuit.compile_detector_sampler(seed=args.seed).sample( + shots=args.shots_per_circuit, + separate_observables=True, + bit_packed=True, + ) + detections = unpack_bit_packed_rows(dets_packed, count=dem.num_detectors) + observables = unpack_bit_packed_rows(obs_packed, count=dem.num_observables) + circuit_name = circuit_path.name + + for shot_index in range(args.shots_per_circuit): + if not args.quiet: + print( + f"[{run_id}] circuit={circuit_name} shot={shot_index} " + f"detectors={int(np.count_nonzero(detections[shot_index]))} ...", + flush=True, + ) + summary = decode_and_probe_shot( + run_id=run_id, + circuit_name=circuit_name, + 
shot_index=shot_index, + shot_detections=detections[shot_index], + shot_observables=observables[shot_index] if observables.size else np.zeros(0, dtype=bool), + data=data, + settings=settings, + logger=logger, + ) + if not args.quiet: + print( + f"[{run_id}] done circuit={circuit_name} shot={shot_index} status={summary['status']} " + f"nodes_popped={summary['nodes_popped']} nodes_created={summary['nodes_created']} " + f"total_lp_time_sec={summary['total_lp_time_sec']:.6f}", + flush=True, + ) + finally: + logger.close() + + if not args.quiet: + print(f"Wrote logs under {output_dir}", flush=True) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/multipass_beam_decoder.py b/src/py/astar/multipass_beam_decoder.py new file mode 100644 index 0000000..adb514c --- /dev/null +++ b/src/py/astar/multipass_beam_decoder.py @@ -0,0 +1,1189 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import heapq +import math +import shutil +import sys +import tempfile +import time +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import stim + + +STIM_RESULT_FORMATS = ("01", "b8", "r8", "ptb64", "hits", "dets") +STIM_RESULT_FORMATS_HELP = "/".join(STIM_RESULT_FORMATS) + + +@dataclass(frozen=True) +class Fault: + q: float + p: float + delta_scale: float + det_mask: int + likelihood_cost: float + + +@dataclass(frozen=True) +class DecoderModel: + faults: tuple[Fault, ...] + retiring_masks: tuple[int, ...] + live_masks_after: tuple[int, ...] + future_detcost: tuple[tuple[float, ...], ...] + all_possible_dets_mask: int + max_width: int + + +@dataclass(frozen=True) +class BeamDecodeResult: + predicted_logical: bool | None + certified: bool + margin: float + discarded_mass: float + max_width: int + elapsed_seconds: float + selected_pass: int = 1 + diagnostic_lines: tuple[str, ...] 
= () + + +@dataclass(frozen=True) +class DecodingShot: + det_mask: int + actual_logical: bool | None + + +@dataclass(frozen=True) +class ExperimentSummary: + predictions: list[bool | None] + num_certified: int + num_low_confidence: int + num_errors: int + num_truth_shots: int + num_scored_shots: int + total_elapsed: float + total_triggered: int + max_width_seen: int + + +HeuristicTables = tuple[dict[int, float], ...] +CandidateStates = tuple[tuple[int, ...], ...] + + +def _likelihood_cost(probability: float) -> float: + if probability <= 0.0: + return math.inf + if probability >= 1.0: + return 0.0 + return -math.log(probability / (1.0 - probability)) + + +def _detectors_from_mask(mask: int) -> list[int]: + detectors: list[int] = [] + while mask: + low_bit = mask & -mask + detectors.append(low_bit.bit_length() - 1) + mask ^= low_bit + return detectors + + +def _mask_from_bool_row(row: np.ndarray) -> int: + mask = 0 + for index in np.flatnonzero(row): + mask |= 1 << int(index) + return mask + + +def _future_detcost_by_detector(faults: tuple[Fault, ...], num_detectors: int) -> tuple[tuple[float, ...], ...]: + future_detcost: list[list[float]] = [[math.inf] * num_detectors for _ in range(len(faults) + 1)] + next_row = future_detcost[-1] + for fault_index in range(len(faults) - 1, -1, -1): + row = next_row.copy() + fault = faults[fault_index] + det_count = fault.det_mask.bit_count() + if det_count: + ecost = fault.likelihood_cost / det_count + for det_id in _detectors_from_mask(fault.det_mask): + if ecost < row[det_id]: + row[det_id] = ecost + future_detcost[fault_index] = row + next_row = row + return tuple(tuple(row) for row in future_detcost) + + +def _build_decoder_model(circuit: stim.Circuit) -> DecoderModel: + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + faults: list[Fault] = [] + all_possible_dets_mask = 0 + last_seen_index: dict[int, int] = {} + + for inst in dem: + if inst.type != "error": + continue + + p = 
float(inst.args_copy()[0]) + det_mask = 0 + flip_l0 = 0 + for target in inst.targets_copy(): + if target.is_separator(): + continue + if target.is_relative_detector_id(): + det_mask ^= 1 << target.val + elif target.is_logical_observable_id() and target.val == 0: + flip_l0 ^= 1 + + faults.append( + Fault( + q=1.0 - p, + p=p, + delta_scale=(-p if flip_l0 else p), + det_mask=det_mask, + likelihood_cost=_likelihood_cost(p), + ) + ) + all_possible_dets_mask |= det_mask + + for det_id in _detectors_from_mask(det_mask): + last_seen_index[det_id] = len(faults) - 1 + + retiring_masks = [0] * len(faults) + for det_id, index in last_seen_index.items(): + retiring_masks[index] |= 1 << det_id + + live_masks_after = [0] * (len(faults) + 1) + active_mask = 0 + max_width = 0 + for i, fault in enumerate(faults): + active_mask |= fault.det_mask + max_width = max(max_width, active_mask.bit_count()) + active_mask &= ~retiring_masks[i] + live_masks_after[i + 1] = active_mask + + frozen_faults = tuple(faults) + return DecoderModel( + faults=frozen_faults, + retiring_masks=tuple(retiring_masks), + live_masks_after=tuple(live_masks_after), + future_detcost=_future_detcost_by_detector(frozen_faults, circuit.num_detectors), + all_possible_dets_mask=all_possible_dets_mask, + max_width=max_width, + ) + + +def _detcost_penalty(mismatch_mask: int, future_detcost: tuple[float, ...]) -> float: + total = 0.0 + pending = mismatch_mask + + while pending: + low_bit = pending & -pending + detector = low_bit.bit_length() - 1 + pending ^= low_bit + + best = future_detcost[detector] + if best == math.inf: + return math.inf + total += best + + return total + + +def _candidate_state_limit(beam: int) -> int: + # One pass feeds a slightly wider neighborhood to the next pass, while still + # capping memory for long circuits. 
+ return max(1, min(128, 2 * beam)) + + +def _top_ranked_entries( + entries: list[tuple[float, float, int, float]], + limit: int, +) -> list[tuple[float, float, int, float]]: + if limit <= 0: + return [] + if len(entries) <= limit: + return sorted(entries, reverse=True) + return heapq.nlargest(limit, entries) + + +def _base_penalty_at_layer( + *, + model: DecoderModel, + live_target_masks: tuple[int, ...], + layer: int, + state: int, +) -> float: + mismatch_mask = state ^ live_target_masks[layer] + return _detcost_penalty(mismatch_mask=mismatch_mask, future_detcost=model.future_detcost[layer]) + + +def _lookup_existing_penalty( + *, + model: DecoderModel, + live_target_masks: tuple[int, ...], + layer: int, + state: int, + heuristic_tables: HeuristicTables | None, +) -> float: + penalty = _base_penalty_at_layer( + model=model, + live_target_masks=live_target_masks, + layer=layer, + state=state, + ) + if heuristic_tables is not None: + refined = heuristic_tables[layer].get(state) + if refined is not None and refined > penalty: + penalty = refined + return penalty + + +def _forward_beam_pass( + *, + model: DecoderModel, + actual_dets_mask: int, + live_target_masks: tuple[int, ...], + beam_width: int, + heuristic_tables: HeuristicTables | None, + collect_candidates: bool, + selected_pass: int, +) -> tuple[BeamDecodeResult, CandidateStates, dict[str, float]]: + beam = [(0, 1.0, 1.0)] + discarded_mass = 0.0 + candidate_limit = _candidate_state_limit(beam_width) if collect_candidates else 0 + + candidate_states_list: list[tuple[int, ...]] = [tuple() for _ in range(len(model.faults) + 1)] + candidate_states_list[0] = (0,) + + stats: dict[str, float] = { + "ranked_states_total": 0.0, + "candidate_states_total": 1.0, + "layers_pruned": 0.0, + "peak_ranked_states": 0.0, + "states_using_refined_lb": 0.0, + "states_blocked_by_refined": 0.0, + "finite_penalty_gain_hits": 0.0, + "total_penalty_uplift": 0.0, + "max_penalty_uplift": 0.0, + } + + for i, fault in 
enumerate(model.faults): + collapsed_probs: dict[int, list[float]] = {} + total_mass = 0.0 + retiring_mask = model.retiring_masks[i] + + if retiring_mask == 0: + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + total_mass += absent_total + entry = collapsed_probs.get(state) + if entry is None: + collapsed_probs[state] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + total_mass += present_total + entry = collapsed_probs.get(toggled) + if entry is None: + collapsed_probs[toggled] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + else: + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + if (state & retiring_mask) == expected_bits: + shrunk = state & keep_mask + total_mass += absent_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + if (toggled & retiring_mask) == expected_bits: + shrunk = toggled & keep_mask + total_mass += present_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + + if total_mass == 0.0: + return ( + BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=discarded_mass, + max_width=model.max_width, + elapsed_seconds=0.0, + selected_pass=selected_pass, + ), + tuple(candidate_states_list), + stats, + ) + + ranked_states: list[tuple[float, float, 
int, float]] = [] + next_live_target_mask = live_target_masks[i + 1] + next_future_detcost = model.future_detcost[i + 1] + next_heuristics = None if heuristic_tables is None else heuristic_tables[i + 1] + + ranked_count = len(collapsed_probs) + stats["ranked_states_total"] += ranked_count + stats["peak_ranked_states"] = max(stats["peak_ranked_states"], float(ranked_count)) + + for state, (total, delta) in collapsed_probs.items(): + mismatch_mask = state ^ next_live_target_mask + base_penalty = _detcost_penalty(mismatch_mask=mismatch_mask, future_detcost=next_future_detcost) + penalty = base_penalty + + if next_heuristics is not None: + refined_penalty = next_heuristics.get(state) + if refined_penalty is not None and refined_penalty > penalty: + penalty = refined_penalty + stats["states_using_refined_lb"] += 1.0 + if refined_penalty == math.inf and base_penalty != math.inf: + stats["states_blocked_by_refined"] += 1.0 + elif refined_penalty != math.inf and base_penalty != math.inf: + uplift = refined_penalty - base_penalty + stats["finite_penalty_gain_hits"] += 1.0 + stats["total_penalty_uplift"] += uplift + stats["max_penalty_uplift"] = max(stats["max_penalty_uplift"], uplift) + + if penalty == math.inf: + rank_score = -math.inf + else: + rank_score = math.log(total) - penalty + ranked_states.append((rank_score, total, state, delta)) + + top_needed = max(beam_width, candidate_limit) + top_entries = _top_ranked_entries(ranked_states, top_needed) + + if collect_candidates: + candidate_slice = top_entries[:candidate_limit] + candidate_states_list[i + 1] = tuple(state for _, _, state, _ in candidate_slice) + stats["candidate_states_total"] += len(candidate_slice) + + dropped_mass = 0.0 + if len(ranked_states) > beam_width: + stats["layers_pruned"] += 1.0 + kept = top_entries[:beam_width] + kept_mass = sum(total for _, total, _, _ in kept) + dropped_mass = total_mass - kept_mass + else: + kept = top_entries + + inv_total_mass = 1.0 / total_mass + discarded_mass = 
(discarded_mass + dropped_mass) * inv_total_mass + beam = [ + (state, total * inv_total_mass, delta * inv_total_mass) + for _, total, state, delta in kept + ] + + candidate_states_list[-1] = (0,) + + _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + margin = abs(final_delta) + certified = margin > discarded_mass + + if final_delta == 0.0: + return ( + BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=margin, + discarded_mass=discarded_mass, + max_width=model.max_width, + elapsed_seconds=0.0, + selected_pass=selected_pass, + ), + tuple(candidate_states_list), + stats, + ) + + return ( + BeamDecodeResult( + predicted_logical=final_delta < 0.0, + certified=certified, + margin=margin, + discarded_mass=discarded_mass, + max_width=model.max_width, + elapsed_seconds=0.0, + selected_pass=selected_pass, + ), + tuple(candidate_states_list), + stats, + ) + + +def _build_refined_lower_bounds( + *, + model: DecoderModel, + actual_dets_mask: int, + live_target_masks: tuple[int, ...], + candidate_states: CandidateStates, + existing_tables: HeuristicTables | None, +) -> tuple[HeuristicTables, dict[str, float]]: + tables_list: list[dict[int, float]] = [dict() for _ in range(len(model.faults) + 1)] + tables_list[-1][0] = 0.0 + + stats: dict[str, float] = { + "candidate_states_total": float(sum(len(layer) for layer in candidate_states)), + "states_evaluated": 1.0, + "layers_with_candidates": 1.0, + "exact_successor_hits": 0.0, + "prior_successor_hits": 0.0, + "base_successor_hits": 0.0, + "states_improved": 0.0, + "states_ruled_out": 0.0, + "finite_gain_hits": 0.0, + "total_lb_gain": 0.0, + "max_lb_gain": 0.0, + } + + for i in range(len(model.faults) - 1, -1, -1): + states_here = candidate_states[i] + if not states_here: + continue + + stats["layers_with_candidates"] += 1.0 + fault = model.faults[i] + retiring_mask = model.retiring_masks[i] + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + 
next_refined = tables_list[i + 1] + current_refined = tables_list[i] + + for state in states_here: + best = math.inf + + if (state & retiring_mask) == expected_bits: + next_state = state & keep_mask + successor_penalty = _lookup_existing_penalty( + model=model, + live_target_masks=live_target_masks, + layer=i + 1, + state=next_state, + heuristic_tables=existing_tables, + ) + exact_successor = next_refined.get(next_state) + if exact_successor is not None and exact_successor > successor_penalty: + successor_penalty = exact_successor + stats["exact_successor_hits"] += 1.0 + elif existing_tables is not None and next_state in existing_tables[i + 1]: + stats["prior_successor_hits"] += 1.0 + else: + stats["base_successor_hits"] += 1.0 + best = min(best, successor_penalty) + + toggled = state ^ fault.det_mask + if (toggled & retiring_mask) == expected_bits: + next_state = toggled & keep_mask + successor_penalty = _lookup_existing_penalty( + model=model, + live_target_masks=live_target_masks, + layer=i + 1, + state=next_state, + heuristic_tables=existing_tables, + ) + exact_successor = next_refined.get(next_state) + if exact_successor is not None and exact_successor > successor_penalty: + successor_penalty = exact_successor + stats["exact_successor_hits"] += 1.0 + elif existing_tables is not None and next_state in existing_tables[i + 1]: + stats["prior_successor_hits"] += 1.0 + else: + stats["base_successor_hits"] += 1.0 + best = min(best, fault.likelihood_cost + successor_penalty) + + old_penalty = _lookup_existing_penalty( + model=model, + live_target_masks=live_target_masks, + layer=i, + state=state, + heuristic_tables=existing_tables, + ) + new_penalty = best if best > old_penalty else old_penalty + current_refined[state] = new_penalty + stats["states_evaluated"] += 1.0 + + if new_penalty > old_penalty: + stats["states_improved"] += 1.0 + if new_penalty == math.inf: + stats["states_ruled_out"] += 1.0 + elif old_penalty != math.inf: + gain = new_penalty - old_penalty + 
stats["finite_gain_hits"] += 1.0 + stats["total_lb_gain"] += gain + stats["max_lb_gain"] = max(stats["max_lb_gain"], gain) + + return tuple(tables_list), stats + + +def _result_confidence_key(result: BeamDecodeResult) -> tuple[float, ...]: + return ( + float(int(result.certified)), + float(int(result.predicted_logical is not None)), + result.margin - result.discarded_mass, + result.margin, + -result.discarded_mass, + float(result.selected_pass), + ) + + +def _format_forward_summary( + *, + shot_index: int | None, + pass_index: int, + num_passes: int, + beam_width: int, + candidate_limit: int, + stats: dict[str, float], + result: BeamDecodeResult, +) -> str: + refined_hits = int(stats["states_using_refined_lb"]) + finite_gain_hits = int(stats["finite_penalty_gain_hits"]) + avg_gain = stats["total_penalty_uplift"] / max(1, finite_gain_hits) + return ( + f"multipass shot={shot_index} pass={pass_index}/{num_passes} phase=forward " + f"beam={beam_width} candidate_limit={candidate_limit} ranked_states={int(stats['ranked_states_total'])} " + f"peak_layer_states={int(stats['peak_ranked_states'])} layers_pruned={int(stats['layers_pruned'])} " + f"refined_hits={refined_hits} refined_blocks={int(stats['states_blocked_by_refined'])} " + f"avg_penalty_gain={avg_gain:.6f} max_penalty_gain={stats['max_penalty_uplift']:.6f} " + f"prediction={result.predicted_logical} certified={result.certified} " + f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e}" + ) + + +def _format_backward_summary( + *, + shot_index: int | None, + pass_index: int, + num_passes: int, + stats: dict[str, float], +) -> str: + finite_gain_hits = int(stats["finite_gain_hits"]) + avg_gain = stats["total_lb_gain"] / max(1, finite_gain_hits) + return ( + f"multipass shot={shot_index} pass={pass_index}/{num_passes} phase=backward " + f"candidate_states={int(stats['candidate_states_total'])} states_evaluated={int(stats['states_evaluated'])} " + 
f"layers_with_candidates={int(stats['layers_with_candidates'])} improved_states={int(stats['states_improved'])} " + f"ruled_out={int(stats['states_ruled_out'])} avg_lb_gain={avg_gain:.6f} " + f"max_lb_gain={stats['max_lb_gain']:.6f} successor_hits=(exact:{int(stats['exact_successor_hits'])}," + f"prior:{int(stats['prior_successor_hits'])},base:{int(stats['base_successor_hits'])})" + ) + + +def _format_selection_summary( + *, + shot_index: int | None, + chosen: BeamDecodeResult, + num_passes: int, +) -> str: + confidence_gap = chosen.margin - chosen.discarded_mass + return ( + f"multipass shot={shot_index} selection chosen_pass={chosen.selected_pass}/{num_passes} " + f"prediction={chosen.predicted_logical} certified={chosen.certified} " + f"confidence_gap={confidence_gap:.6e} margin={chosen.margin:.6e} " + f"discarded_mass={chosen.discarded_mass:.6e}" + ) + + +def _as_bool_2d(data: np.ndarray, *, expected_cols: int, description: str) -> np.ndarray: + arr = np.asarray(data) + if arr.ndim != 2: + raise ValueError(f"Expected {description} to be a 2D array but got shape {arr.shape!r}.") + if arr.shape[1] != expected_cols: + raise ValueError( + f"Expected {description} to have {expected_cols} columns but got {arr.shape[1]}." 
+ ) + if arr.dtype != np.bool_: + arr = arr.astype(np.bool_, copy=False) + return arr + + +def _sample_shot_arrays( + circuit: stim.Circuit, + *, + shots: int, + seed: int | None, +) -> tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets, obs = sampler.sample(shots=shots, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=circuit.num_detectors, description="sampled detector data"), + _as_bool_2d(obs, expected_cols=circuit.num_observables, description="sampled observable data"), + ) + + +def _read_detector_shot_arrays( + *, + path: str, + fmt: str, + num_detectors: int, + num_observables: int, +) -> tuple[np.ndarray, np.ndarray | None]: + flat = stim.read_shot_data_file( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=num_detectors, + num_observables=num_observables, + ) + + expected_cols = num_detectors + num_observables + flat = _as_bool_2d( + flat, + expected_cols=expected_cols, + description="combined detector/observable input data", + ) + if num_observables: + return flat[:, :num_detectors], flat[:, num_detectors:] + return flat, None + + +def _read_observable_shot_array(*, path: str, fmt: str, num_observables: int) -> np.ndarray: + obs = stim.read_shot_data_file( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=0, + num_observables=num_observables, + ) + return _as_bool_2d(obs, expected_cols=num_observables, description="observable input data") + + +def _apply_shot_range( + dets: np.ndarray, + obs: np.ndarray | None, + *, + shot_range_begin: int, + shot_range_end: int, +) -> tuple[np.ndarray, np.ndarray | None]: + if not (shot_range_begin or shot_range_end): + return dets, obs + + if shot_range_end < shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if shot_range_end > len(dets): + raise ValueError( + f"Shot range end {shot_range_end} is past the end of the shot 
data (size {len(dets)})." + ) + + dets = dets[shot_range_begin:shot_range_end] + if obs is not None: + obs = obs[shot_range_begin:shot_range_end] + return dets, obs + + +def _shots_from_arrays(dets: np.ndarray, obs: np.ndarray | None) -> list[DecodingShot]: + shots: list[DecodingShot] = [] + for shot_index in range(dets.shape[0]): + actual_logical = None if obs is None else bool(obs[shot_index, 0]) + shots.append( + DecodingShot( + det_mask=_mask_from_bool_row(dets[shot_index]), + actual_logical=actual_logical, + ) + ) + return shots + + +def _resolve_stdin_path_if_needed(path: str, *, temp_dir: str, stem: str) -> str: + if path != "-": + return path + temp_path = str(Path(temp_dir) / f"{stem}.bin") + with open(temp_path, "wb") as f: + f.write(sys.stdin.buffer.read()) + return temp_path + + +def _resolve_stdout_path_if_needed(path: str, *, temp_dir: str, stem: str) -> tuple[str, bool]: + if path != "-": + return path, False + return str(Path(temp_dir) / f"{stem}.bin"), True + + +def _copy_file_to_stdout(path: str) -> None: + sys.stdout.flush() + with open(path, "rb") as f: + shutil.copyfileobj(f, sys.stdout.buffer) + sys.stdout.buffer.flush() + + +def _load_shots( + circuit: stim.Circuit, + args: argparse.Namespace, + *, + temp_dir: str, +) -> list[DecodingShot]: + if args.in_file: + in_path = _resolve_stdin_path_if_needed(args.in_file, temp_dir=temp_dir, stem="shots_in") + appended_obs_count = circuit.num_observables if args.in_includes_appended_observables else 0 + dets, obs = _read_detector_shot_arrays( + path=in_path, + fmt=args.in_format, + num_detectors=circuit.num_detectors, + num_observables=appended_obs_count, + ) + + if args.obs_in_file: + obs_in_path = _resolve_stdin_path_if_needed(args.obs_in_file, temp_dir=temp_dir, stem="obs_in") + obs = _read_observable_shot_array( + path=obs_in_path, + fmt=args.obs_in_format, + num_observables=circuit.num_observables, + ) + if len(obs) != len(dets): + raise ValueError("Observable input ended before, or after, the 
detector shot data.") + else: + dets, obs = _sample_shot_arrays(circuit, shots=args.sample_num_shots, seed=args.sample_seed) + + dets, obs = _apply_shot_range( + dets, + obs, + shot_range_begin=args.shot_range_begin, + shot_range_end=args.shot_range_end, + ) + return _shots_from_arrays(dets, obs) + + +def decode_beam_search_detcost_ranked( + model: DecoderModel, + actual_dets_mask: int, + L: int, + *, + num_passes: int = 1, + shot_index: int | None = None, +) -> BeamDecodeResult: + start_time = time.perf_counter() + + if (actual_dets_mask & ~model.all_possible_dets_mask) != 0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=0.0, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + selected_pass=1, + ) + + live_target_masks = tuple(actual_dets_mask & mask for mask in model.live_masks_after) + candidate_limit = _candidate_state_limit(L) + current_tables: HeuristicTables | None = None + per_pass_results: list[BeamDecodeResult] = [] + diagnostic_lines: list[str] = [] + + for pass_index in range(1, num_passes + 1): + collect_candidates = pass_index < num_passes + pass_result, candidate_states, forward_stats = _forward_beam_pass( + model=model, + actual_dets_mask=actual_dets_mask, + live_target_masks=live_target_masks, + beam_width=L, + heuristic_tables=current_tables, + collect_candidates=collect_candidates, + selected_pass=pass_index, + ) + per_pass_results.append(pass_result) + + if num_passes > 1: + diagnostic_lines.append( + _format_forward_summary( + shot_index=shot_index, + pass_index=pass_index, + num_passes=num_passes, + beam_width=L, + candidate_limit=candidate_limit, + stats=forward_stats, + result=pass_result, + ) + ) + + if collect_candidates: + current_tables, backward_stats = _build_refined_lower_bounds( + model=model, + actual_dets_mask=actual_dets_mask, + live_target_masks=live_target_masks, + candidate_states=candidate_states, + existing_tables=current_tables, + ) + if 
num_passes > 1: + diagnostic_lines.append( + _format_backward_summary( + shot_index=shot_index, + pass_index=pass_index, + num_passes=num_passes, + stats=backward_stats, + ) + ) + + chosen = max(per_pass_results, key=_result_confidence_key) + if num_passes > 1: + diagnostic_lines.append( + _format_selection_summary( + shot_index=shot_index, + chosen=chosen, + num_passes=num_passes, + ) + ) + + return BeamDecodeResult( + predicted_logical=chosen.predicted_logical, + certified=chosen.certified, + margin=chosen.margin, + discarded_mass=chosen.discarded_mass, + max_width=chosen.max_width, + elapsed_seconds=time.perf_counter() - start_time, + selected_pass=chosen.selected_pass, + diagnostic_lines=tuple(diagnostic_lines), + ) + + +def _print_run_header( + *, + circuit: stim.Circuit, + args: argparse.Namespace, + num_shots: int, + log_stream, +) -> None: + print(f"Running on circuit {args.circuit}", file=log_stream) + print(f"Total Detectors: {circuit.num_detectors}", file=log_stream) + print(f"Total Observables: {circuit.num_observables}", file=log_stream) + if args.in_file: + print(f"Shot Input: {args.in_file}", file=log_stream) + print(f"Shot Input Format: {args.in_format}", file=log_stream) + if args.in_includes_appended_observables: + print("Observable Input: appended to --in", file=log_stream) + elif args.obs_in_file: + print(f"Observable Input: {args.obs_in_file}", file=log_stream) + print(f"Observable Format: {args.obs_in_format}", file=log_stream) + else: + print("Observable Input: none", file=log_stream) + else: + print(f"Sample Seed: {args.sample_seed}", file=log_stream) + print(f"Requested Shots: {args.sample_num_shots}", file=log_stream) + if args.shot_range_begin or args.shot_range_end: + print( + f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})", + file=log_stream, + ) + print(f"Beam: {args.beam}", file=log_stream) + print(f"Num Passes: {args.num_passes}", file=log_stream) + if args.num_passes > 1: + print(f"Pass Candidate Limit: 
{_candidate_state_limit(args.beam)}", file=log_stream) + print( + "Pass Logic: forward beam -> candidate residual states -> backward Bellman lower bounds -> choose best-confidence pass", + file=log_stream, + ) + print(f"Num Shots: {num_shots}", file=log_stream) + + +def run_experiment(args: argparse.Namespace) -> ExperimentSummary: + circuit = stim.Circuit.from_file(args.circuit) + if circuit.num_observables != 1: + raise ValueError( + "This decoder currently supports exactly one logical observable, because it only tracks L0. " + f"The circuit has {circuit.num_observables} observables." + ) + + model = _build_decoder_model(circuit) + log_stream = sys.stderr if args.out_file == "-" else sys.stdout + + with tempfile.TemporaryDirectory() as temp_dir: + shots = _load_shots(circuit, args, temp_dir=temp_dir) + _print_run_header(circuit=circuit, args=args, num_shots=len(shots), log_stream=log_stream) + + num_errors = 0 + num_low_confidence = 0 + num_certified = 0 + num_truth_shots = 0 + num_scored_shots = 0 + total_elapsed = 0.0 + total_triggered = 0 + max_width_seen = 0 + predictions: list[bool | None] = [] + selected_pass_counts = [0] * (args.num_passes + 1) + + detailed_multipass_for_all = args.print_per_shot or len(shots) <= 10 + + for shot_index, shot in enumerate(shots): + result = decode_beam_search_detcost_ranked( + model, + shot.det_mask, + args.beam, + num_passes=args.num_passes, + shot_index=shot_index, + ) + predictions.append(result.predicted_logical) + selected_pass_counts[result.selected_pass] += 1 + + if result.diagnostic_lines: + if detailed_multipass_for_all or shot_index == 0: + for line in result.diagnostic_lines: + print(line, file=log_stream) + else: + print(result.diagnostic_lines[-1], file=log_stream) + + success: bool | None + if shot.actual_logical is None or result.predicted_logical is None: + success = None + else: + success = result.predicted_logical == shot.actual_logical + + if result.predicted_logical is None: + num_low_confidence += 1 + if 
shot.actual_logical is not None: + num_truth_shots += 1 + if success is not None: + num_scored_shots += 1 + if not success: + num_errors += 1 + if result.certified: + num_certified += 1 + + total_elapsed += result.elapsed_seconds + triggered_dets = shot.det_mask.bit_count() + total_triggered += triggered_dets + max_width_seen = max(max_width_seen, result.max_width) + + shots_done = shot_index + 1 + error_rate_so_far = num_errors / num_scored_shots if num_scored_shots else 0.0 + progress_line = ( + f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} " + f"low_conf_so_far={num_low_confidence} scored_shots_so_far={num_scored_shots} " + f"error_rate_so_far={error_rate_so_far:.6f} elapsed_total_seconds={total_elapsed:.6f}" + ) + if args.num_passes > 1: + progress_line += f" selected_pass={result.selected_pass}" + print(progress_line, file=log_stream) + + if args.print_per_shot: + print( + f"shot={shot_index} triggered_detectors={triggered_dets} " + f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} " + f"success={success} certified={result.certified} selected_pass={result.selected_pass} " + f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} " + f"elapsed_seconds={result.elapsed_seconds:.6f}", + file=log_stream, + ) + + if args.out_file: + output_path, copy_to_stdout = _resolve_stdout_path_if_needed( + args.out_file, + temp_dir=temp_dir, + stem="predictions_out", + ) + prediction_data = np.zeros((len(predictions), circuit.num_observables), dtype=np.bool_) + for shot_index, predicted_logical in enumerate(predictions): + prediction_data[shot_index, 0] = bool(predicted_logical) if predicted_logical is not None else False + + if args.out_format == "ptb64" and len(prediction_data) % 64 != 0: + raise ValueError("The ptb64 format requires the number of shots to be a multiple of 64.") + + stim.write_shot_data_file( + data=prediction_data, + path=output_path, + format=args.out_format, + 
num_measurements=0, + num_detectors=0, + num_observables=circuit.num_observables, + ) + if copy_to_stdout: + _copy_file_to_stdout(output_path) + if num_low_confidence: + print( + f"warning: wrote {num_low_confidence} low-confidence predictions as L0=0 because Stim result " + "files can only store bits, not unknown values.", + file=log_stream, + ) + + print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}", file=log_stream) + print(f"Max Width: {max_width_seen}", file=log_stream) + print(f"Certified Shots: {num_certified}", file=log_stream) + print(f"Low Confidence: {num_low_confidence}", file=log_stream) + print(f"Truth-Labeled Shots: {num_truth_shots}", file=log_stream) + print(f"Scored Shots: {num_scored_shots}", file=log_stream) + if args.num_passes > 1: + selected_summary = " ".join( + f"P{pass_index}={count}" + for pass_index, count in enumerate(selected_pass_counts[1:], start=1) + ) + print(f"Selected Passes: {selected_summary}", file=log_stream) + if num_truth_shots: + print(f"Logical Errors: {num_errors}", file=log_stream) + else: + print("Logical Errors: n/a", file=log_stream) + print(f"Total Seconds: {total_elapsed:.6f}", file=log_stream) + print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}", file=log_stream) + + return ExperimentSummary( + predictions=predictions, + num_certified=num_certified, + num_low_confidence=num_low_confidence, + num_errors=num_errors, + num_truth_shots=num_truth_shots, + num_scored_shots=num_scored_shots, + total_elapsed=total_elapsed, + total_triggered=total_triggered, + max_width_seen=max_width_seen, + ) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Run trellis beam decoding ranked by mass minus a detcost-style future penalty, " + "optionally refined by multi-pass candidate-state Bellman backups, " + "with Stim-compatible shot-data I/O options." 
+ ), + allow_abbrev=False, + ) + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") + parser.add_argument( + "--num-passes", + type=int, + default=1, + help=( + "Number of forward/backward refinement passes. 1 reproduces the original single-pass beam search. " + "Larger values reuse beam states from one pass to sharpen the remaining-cost estimates of the next." + ), + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=None, + help="Number of sampled shots. Defaults to 1 unless --in is provided.", + ) + parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.") + parser.add_argument( + "--shot-range-begin", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." + ), + ) + parser.add_argument( + "--shot-range-end", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." 
+ ), + ) + parser.add_argument( + "--in", + dest="in_file", + default="", + help="File to read detection events from (use - for stdin).", + ) + parser.add_argument( + "--in-format", + "--in_format", + dest="in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--in-includes-appended-observables", + "--in_includes_appended_observables", + dest="in_includes_appended_observables", + action="store_true", + help="Assume the observable flips are appended to each shot in --in.", + ) + parser.add_argument( + "--obs-in", + "--obs_in", + dest="obs_in_file", + default="", + help="File to read observable flips from (use - for stdin).", + ) + parser.add_argument( + "--obs-in-format", + "--obs_in_format", + dest="obs_in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --obs-in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--out", + dest="out_file", + default="", + help="File to write predicted observable flips to (use - for stdout).", + ) + parser.add_argument( + "--out-format", + "--out_format", + dest="out_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file written by --out ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--print-per-shot", + action="store_true", + help="Print a detailed line per decoded shot.", + ) + args = parser.parse_args() + + if args.sample_num_shots is None: + # Preserve the original script's one-shot default while still allowing + # file input without requiring --sample-num-shots 0. 
+ args.sample_num_shots = 0 if args.in_file else 1 + + if args.beam <= 0: + raise ValueError("--beam must be positive.") + if args.num_passes <= 0: + raise ValueError("--num-passes must be positive.") + if args.sample_num_shots < 0: + raise ValueError("--sample-num-shots must be non-negative.") + if args.sample_seed is not None and args.sample_seed < 0: + raise ValueError("--sample-seed must be non-negative.") + if args.shot_range_begin < 0 or args.shot_range_end < 0: + raise ValueError("--shot-range-begin and --shot-range-end must be non-negative.") + if args.shot_range_end < args.shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if args.in_includes_appended_observables and args.obs_in_file: + raise ValueError( + "Choose either --in-includes-appended-observables or --obs-in, not both." + ) + if args.obs_in_file and not args.in_file: + raise ValueError("Cannot load observable flips from --obs-in without also providing --in.") + if args.in_file == "-" and args.obs_in_file == "-": + raise ValueError("At most one of --in and --obs-in may read from stdin.") + + num_shot_sources = int(args.sample_num_shots > 0) + int(bool(args.in_file)) + if num_shot_sources != 1: + raise ValueError("Requires exactly one source of shots: either --sample-num-shots > 0 or --in.") + + return args + + +if __name__ == "__main__": + run_experiment(_parse_args()) diff --git a/src/py/astar/plot_log.py b/src/py/astar/plot_log.py new file mode 100644 index 0000000..56e7687 --- /dev/null +++ b/src/py/astar/plot_log.py @@ -0,0 +1,156 @@ +import sys +import os +import matplotlib.pyplot as plt +import numpy as np + +def analyze_log(filename): + min_masses = [] + errors = [] + + current_errors = 0 + current_low_conf = 0 + + pending_error_diff = None + pending_low_conf_diff = None + + # Parse the log file line by line + with open(filename, 'r') as f: + for line in f: + parts = line.split() + if not parts: + continue + + if parts[0] == 
"num_shots": + # Find 'num_errors' and 'num_low_confidence' and grab the values + idx_err = parts.index("num_errors") + errs = int(parts[idx_err + 2]) + + idx_lc = parts.index("num_low_confidence") + lc = int(parts[idx_lc + 2]) + + # Calculate diffs for this specific shot + pending_error_diff = errs - current_errors + pending_low_conf_diff = lc - current_low_conf + + current_errors = errs + current_low_conf = lc + + elif parts[0] == "branch_masses": + obs0 = float(parts[1].split("=")[1]) + obs1 = float(parts[2].split("=")[1]) + + # Override if it was flagged as a low confidence shot + if pending_low_conf_diff is not None and pending_low_conf_diff > 0: + obs0 = 0.5 + obs1 = 0.5 + # Count the low confidence increment as additional logical errors + pending_error_diff += pending_low_conf_diff + else: + norm = obs0 + obs1 + if norm == 0: + obs0 = 0.5 + obs1 = 0.5 + else: + obs0 /= norm + obs1 /= norm + + # Only append if we just successfully parsed a num_shots line + if pending_error_diff is not None: + min_masses.append(min(obs0, obs1)) + errors.append(pending_error_diff) + + # Reset pending diffs to ensure we don't double-count + pending_error_diff = None + pending_low_conf_diff = None + + min_masses = np.array(min_masses) + errors = np.array(errors) + + if len(min_masses) == 0: + print("No valid shot data found in the file.") + return + + # To calculate how error rates change based on our cutoff, + # we sort the shots from most certain (lowest min_mass) to least certain. 
+ sorted_idx = np.argsort(min_masses) + sorted_masses = min_masses[sorted_idx] + sorted_errors = errors[sorted_idx] + + N = len(sorted_masses) + + # K represents the number of shots we *accept* (1 to N) + K_arr = np.arange(1, N + 1) + + # Cumulative errors in the accepted subset of shots + accepted_errors = np.cumsum(sorted_errors) + + # Error rate = (errors in accepted subset) / (number of accepted shots) + error_rates = accepted_errors / K_arr + + # Rejection rate = (number of rejected shots) / (total shots) + rejection_rates = (N - K_arr) / N + + # ------------------ + # Pre-process for Log Scale Histogram + # ------------------ + # Find the smallest non-zero mass. If everything is 0, default to 1e-10 + if np.any(min_masses > 0): + min_nonzero = np.min(min_masses[min_masses > 0]) + # Set exact 0s to half the minimum non-zero value so they fall in the leftmost bin + epsilon = min_nonzero / 2.0 + else: + epsilon = 1e-10 + + # Replace 0s with epsilon + masses_for_hist = np.where(min_masses == 0, epsilon, min_masses) + + # Safely get max mass to define bin edges + max_mass = np.max(masses_for_hist) + if max_mass == epsilon: + max_mass = epsilon * 10 # Fallback in case all values were 0 + + # Generate 50 logarithmically spaced bins + log_bins = np.logspace(np.log10(epsilon), np.log10(max_mass), 50) + + # ------------------ + # Create the Figures + # ------------------ + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + + # Plot 1: Distribution of min masses (Log Scale X) + axes[0].hist(masses_for_hist, bins=log_bins, color='skyblue', edgecolor='black') + axes[0].set_xscale('log') + axes[0].set_xlabel('Min Mass (Log Scale, 0s in leftmost bin)') + axes[0].set_ylabel('Frequency') + axes[0].set_title('Distribution of Min Masses') + + # Plot 2: Logical error rate vs Min Mass Cutoff + axes[1].plot(sorted_masses, error_rates, color='purple', lw=2) + axes[1].set_xlabel('Min Mass Cutoff (Threshold)') + axes[1].set_ylabel('Logical Error Rate (Accepted Shots)') + 
axes[1].set_title('Error Rate vs Min Mass Cutoff') + axes[1].grid(True, linestyle='--', alpha=0.7) + + # Plot 3: Logical error rate vs Rejection rate + axes[2].plot(rejection_rates, error_rates, color='red', lw=2) + axes[2].set_xlabel('Rejection Rate') + axes[2].set_ylabel('Logical Error Rate (Accepted Shots)') + axes[2].set_title('Error Rate vs Rejection Rate') + axes[2].grid(True, linestyle='--', alpha=0.7) + axes[2].set_xlim(0, 1) + + plt.tight_layout() + + # Generate output filename based on input filename + base_name = os.path.splitext(os.path.basename(filename))[0] + out_filename = f"{base_name}_analysis.png" + + # Save to disk instead of displaying + plt.savefig(out_filename, dpi=300, bbox_inches='tight') + print(f"Success! Plot saved to disk as: {out_filename}") + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: python {sys.argv[0]} ") + else: + analyze_log(sys.argv[1]) diff --git a/src/py/astar/trellis_beam.py b/src/py/astar/trellis_beam.py new file mode 100644 index 0000000..1b69a96 --- /dev/null +++ b/src/py/astar/trellis_beam.py @@ -0,0 +1,250 @@ +import argparse +import heapq +import time +from dataclasses import dataclass +from operator import itemgetter + +import stim + + +@dataclass(frozen=True) +class BeamDecodeResult: + predicted_logical: bool | None + certified: bool + margin: float + discarded_mass: float + max_width: int + elapsed_seconds: float + + +def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> BeamDecodeResult: + """ + Decodes a syndrome using a dynamic programming sweep with a Top-L beam cutoff. + """ + start_time = time.perf_counter() + + # 1. Extract the Detector Error Model (flattened, decompose_errors=False) + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + # 2. 
Parse the DEM into a list of faults + faults = [] + all_possible_dets_mask = 0 + + for inst in dem: + if inst.type != "error": + continue + + p = inst.args_copy()[0] + det_mask = 0 + flip_l0 = 0 + + for t in inst.targets_copy(): + if t.is_separator(): + continue + if t.is_relative_detector_id(): + det_mask ^= (1 << t.val) + elif t.is_logical_observable_id() and t.val == 0: + flip_l0 ^= 1 + + q = 1.0 - p + delta_scale = -p if flip_l0 else p + faults.append((q, p, delta_scale, det_mask)) + all_possible_dets_mask |= det_mask + + # 3. Convert observed syndrome set to an integer bitmask + actual_dets_mask = 0 + for d in actual_dets: + actual_dets_mask ^= (1 << d) + + # If the quantum computer triggered a detector that our error model says + # is mathematically impossible to trigger, the syndrome is invalid. + if (actual_dets_mask & ~all_possible_dets_mask) != 0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=0.0, + max_width=0, + elapsed_seconds=time.perf_counter() - start_time, + ) + + # 4. Pre-calculate retirement schedules + retiring_masks = [0] * len(faults) + last_seen_index = {} + + for idx, (_, _, _, det_mask) in enumerate(faults): + temp = det_mask + d_id = 0 + while temp > 0: + if temp & 1: + last_seen_index[d_id] = idx + temp >>= 1 + d_id += 1 + + for d_id, idx in last_seen_index.items(): + retiring_masks[idx] |= (1 << d_id) + + active_mask = 0 + max_width = 0 + for i, (_, _, _, det_mask) in enumerate(faults): + active_mask |= det_mask + max_width = max(max_width, active_mask.bit_count()) + active_mask &= ~retiring_masks[i] + + # 5. The Beam Search Sweep + beam = [(0, 1.0, 1.0)] + discarded_mass = 0.0 + + for i, (q, p, delta_scale, det_mask) in enumerate(faults): + next_probs: dict[int, list[float]] = {} + + # A. 
Expand the beam + for s, total, delta in beam: + entry = next_probs.get(s) + absent_total = total * q + absent_delta = delta * q + if entry is None: + next_probs[s] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + t = s ^ det_mask + present_total = total * p + present_delta = delta * delta_scale + if t == s: + entry = next_probs[s] + entry[0] += present_total + entry[1] += present_delta + else: + entry = next_probs.get(t) + if entry is None: + next_probs[t] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + + # B. Enforce Reality & Collapse the State Space + retiring_mask = retiring_masks[i] + if retiring_mask != 0: + collapsed_probs: dict[int, list[float]] = {} + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + + for s, (total, delta) in next_probs.items(): + if (s & retiring_mask) != expected_bits: + continue + + shrunk_s = s & keep_mask + entry = collapsed_probs.get(shrunk_s) + if entry is None: + collapsed_probs[shrunk_s] = [total, delta] + else: + entry[0] += total + entry[1] += delta + else: + collapsed_probs = next_probs + + # C. 
Truncate the Beam (Top L Cutoff) + total_mass = sum(total for total, _ in collapsed_probs.values()) + if total_mass == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + dropped_mass = 0.0 + if len(collapsed_probs) > L: + beam = heapq.nlargest( + L, + ( + (state, total, delta) + for state, (total, delta) in collapsed_probs.items() + ), + key=itemgetter(1), + ) + kept_mass = sum(total for _, total, _ in beam) + dropped_mass = total_mass - kept_mass + else: + beam = [ + (state, total, delta) + for state, (total, delta) in collapsed_probs.items() + ] + + inv_total_mass = 1.0 / total_mass + discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass + beam = [ + (state, total * inv_total_mass, delta * inv_total_mass) + for state, total, delta in beam + ] + + # 6. Final Likelihood Comparison + _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + margin = abs(final_delta) + certified = margin > discarded_mass + + if final_delta == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=margin, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + return BeamDecodeResult( + predicted_logical=final_delta < 0.0, + certified=certified, + margin=margin, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + +def run_experiment(circuit_fname: str, L: int, seed=None): + print(f"Running on circuit {circuit_fname}") + + circuit = stim.Circuit.from_file(circuit_fname) + + sampler = circuit.compile_detector_sampler(seed=seed) + syndromes, logicals = sampler.sample(shots=1, separate_observables=True) + + actual_dets = set(i for i, triggered in enumerate(syndromes[0]) if triggered) + actual_logical = logicals[0][0] + + result = 
decode_beam_search(circuit, actual_dets, L) + + print(f"Total Detectors: {circuit.num_detectors}") + print(f"Seed: {seed}") + print(f"Triggered Detectors: {len(actual_dets)}") + print(f"Width: {result.max_width}") + print(f"Predicted Logical: {result.predicted_logical}") + print(f"Actual Logical: {bool(actual_logical)}") + print(f"Certified: {result.certified}") + print(f"Margin: {result.margin:.6e}") + print(f"Discarded Mass: {result.discarded_mass:.6e}") + print(f"Elapsed Seconds: {result.elapsed_seconds:.6f}") + + if result.predicted_logical is None: + print("Result: DECODE FAILED (Tie or Beam too narrow)") + else: + print(f"Result: {'SUCCESS' if result.predicted_logical == actual_logical else 'LOGICAL ERROR'}") + print() + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run one-shot trellis beam decoding on a Stim circuit.") + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") + parser.add_argument("--seed", type=int, default=None, help="Sampler seed.") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + run_experiment(args.circuit, L=args.beam, seed=args.seed) diff --git a/src/py/astar/trellis_beam_detcost_ranked.py b/src/py/astar/trellis_beam_detcost_ranked.py new file mode 100644 index 0000000..be934ee --- /dev/null +++ b/src/py/astar/trellis_beam_detcost_ranked.py @@ -0,0 +1,777 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import math +import shutil +import sys +import tempfile +import time +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import stim + + +STIM_RESULT_FORMATS = ("01", "b8", "r8", "ptb64", "hits", "dets") +STIM_RESULT_FORMATS_HELP = "/".join(STIM_RESULT_FORMATS) + + +@dataclass(frozen=True) +class Fault: + q: float + p: float + delta_scale: float + det_mask: int + 
likelihood_cost: float + + +@dataclass(frozen=True) +class DecoderModel: + faults: tuple[Fault, ...] + retiring_masks: tuple[int, ...] + live_masks_after: tuple[int, ...] + future_detcost: tuple[tuple[float, ...], ...] + all_possible_dets_mask: int + max_width: int + + +@dataclass(frozen=True) +class BeamDecodeResult: + predicted_logical: bool | None + certified: bool + margin: float + discarded_mass: float + max_width: int + elapsed_seconds: float + + +@dataclass(frozen=True) +class DecodingShot: + det_mask: int + actual_logical: bool | None + + +@dataclass(frozen=True) +class ExperimentSummary: + predictions: list[bool | None] + num_certified: int + num_low_confidence: int + num_errors: int + num_truth_shots: int + num_scored_shots: int + total_elapsed: float + total_triggered: int + max_width_seen: int + + +def _likelihood_cost(probability: float) -> float: + if probability <= 0.0: + return math.inf + if probability >= 1.0: + return 0.0 + return -math.log(probability / (1.0 - probability)) + + +def _detectors_from_mask(mask: int) -> list[int]: + detectors: list[int] = [] + while mask: + low_bit = mask & -mask + detectors.append(low_bit.bit_length() - 1) + mask ^= low_bit + return detectors + + +def _mask_from_bool_row(row: np.ndarray) -> int: + mask = 0 + for index in np.flatnonzero(row): + mask |= 1 << int(index) + return mask + + +def _future_detcost_by_detector(faults: tuple[Fault, ...], num_detectors: int) -> tuple[tuple[float, ...], ...]: + future_detcost: list[list[float]] = [[math.inf] * num_detectors for _ in range(len(faults) + 1)] + next_row = future_detcost[-1] + for fault_index in range(len(faults) - 1, -1, -1): + row = next_row.copy() + fault = faults[fault_index] + det_count = fault.det_mask.bit_count() + if det_count: + ecost = fault.likelihood_cost / det_count + for det_id in _detectors_from_mask(fault.det_mask): + if ecost < row[det_id]: + row[det_id] = ecost + future_detcost[fault_index] = row + next_row = row + return tuple(tuple(row) for row 
in future_detcost) + + +def _build_decoder_model(circuit: stim.Circuit) -> DecoderModel: + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + faults: list[Fault] = [] + all_possible_dets_mask = 0 + last_seen_index: dict[int, int] = {} + + for inst in dem: + if inst.type != "error": + continue + + p = float(inst.args_copy()[0]) + det_mask = 0 + flip_l0 = 0 + for target in inst.targets_copy(): + if target.is_separator(): + continue + if target.is_relative_detector_id(): + det_mask ^= 1 << target.val + elif target.is_logical_observable_id() and target.val == 0: + flip_l0 ^= 1 + + faults.append( + Fault( + q=1.0 - p, + p=p, + delta_scale=(-p if flip_l0 else p), + det_mask=det_mask, + likelihood_cost=_likelihood_cost(p), + ) + ) + all_possible_dets_mask |= det_mask + + for det_id in _detectors_from_mask(det_mask): + last_seen_index[det_id] = len(faults) - 1 + + retiring_masks = [0] * len(faults) + for det_id, index in last_seen_index.items(): + retiring_masks[index] |= 1 << det_id + + live_masks_after = [0] * (len(faults) + 1) + active_mask = 0 + max_width = 0 + for i, fault in enumerate(faults): + active_mask |= fault.det_mask + max_width = max(max_width, active_mask.bit_count()) + active_mask &= ~retiring_masks[i] + live_masks_after[i + 1] = active_mask + + frozen_faults = tuple(faults) + return DecoderModel( + faults=frozen_faults, + retiring_masks=tuple(retiring_masks), + live_masks_after=tuple(live_masks_after), + future_detcost=_future_detcost_by_detector(frozen_faults, circuit.num_detectors), + all_possible_dets_mask=all_possible_dets_mask, + max_width=max_width, + ) + + +def _detcost_penalty(mismatch_mask: int, future_detcost: tuple[float, ...]) -> float: + total = 0.0 + pending = mismatch_mask + + while pending: + low_bit = pending & -pending + detector = low_bit.bit_length() - 1 + pending ^= low_bit + + best = future_detcost[detector] + if best == math.inf: + return math.inf + total += best + + return total + + +def _as_bool_2d(data: 
np.ndarray, *, expected_cols: int, description: str) -> np.ndarray: + arr = np.asarray(data) + if arr.ndim != 2: + raise ValueError(f"Expected {description} to be a 2D array but got shape {arr.shape!r}.") + if arr.shape[1] != expected_cols: + raise ValueError( + f"Expected {description} to have {expected_cols} columns but got {arr.shape[1]}." + ) + if arr.dtype != np.bool_: + arr = arr.astype(np.bool_, copy=False) + return arr + + +def _sample_shot_arrays( + circuit: stim.Circuit, + *, + shots: int, + seed: int | None, +) -> tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets, obs = sampler.sample(shots=shots, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=circuit.num_detectors, description="sampled detector data"), + _as_bool_2d(obs, expected_cols=circuit.num_observables, description="sampled observable data"), + ) + + +def _read_detector_shot_arrays( + *, + path: str, + fmt: str, + num_detectors: int, + num_observables: int, +) -> tuple[np.ndarray, np.ndarray | None]: + common_kwargs = dict( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=num_detectors, + num_observables=num_observables, + ) + + if num_observables: + try: + dets, obs = stim.read_shot_data_file(**common_kwargs, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=num_detectors, description="input detector data"), + _as_bool_2d(obs, expected_cols=num_observables, description="appended observable data"), + ) + except TypeError: + flat = stim.read_shot_data_file(**common_kwargs) + flat = _as_bool_2d( + flat, + expected_cols=num_detectors + num_observables, + description="combined detector/observable input data", + ) + return flat[:, :num_detectors], flat[:, num_detectors:] + + flat = stim.read_shot_data_file(**common_kwargs) + return _as_bool_2d(flat, expected_cols=num_detectors, description="input detector data"), None + + +def _read_observable_shot_array(*, path: str, fmt: str, 
num_observables: int) -> np.ndarray: + obs = stim.read_shot_data_file( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=0, + num_observables=num_observables, + ) + return _as_bool_2d(obs, expected_cols=num_observables, description="observable input data") + + +def _apply_shot_range( + dets: np.ndarray, + obs: np.ndarray | None, + *, + shot_range_begin: int, + shot_range_end: int, +) -> tuple[np.ndarray, np.ndarray | None]: + if not (shot_range_begin or shot_range_end): + return dets, obs + + if shot_range_end < shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if shot_range_end > len(dets): + raise ValueError( + f"Shot range end {shot_range_end} is past the end of the shot data (size {len(dets)})." + ) + + dets = dets[shot_range_begin:shot_range_end] + if obs is not None: + obs = obs[shot_range_begin:shot_range_end] + return dets, obs + + +def _shots_from_arrays(dets: np.ndarray, obs: np.ndarray | None) -> list[DecodingShot]: + shots: list[DecodingShot] = [] + for shot_index in range(dets.shape[0]): + actual_logical = None if obs is None else bool(obs[shot_index, 0]) + shots.append( + DecodingShot( + det_mask=_mask_from_bool_row(dets[shot_index]), + actual_logical=actual_logical, + ) + ) + return shots + + +def _resolve_stdin_path_if_needed(path: str, *, temp_dir: str, stem: str) -> str: + if path != "-": + return path + temp_path = str(Path(temp_dir) / f"{stem}.bin") + with open(temp_path, "wb") as f: + f.write(sys.stdin.buffer.read()) + return temp_path + + +def _resolve_stdout_path_if_needed(path: str, *, temp_dir: str, stem: str) -> tuple[str, bool]: + if path != "-": + return path, False + return str(Path(temp_dir) / f"{stem}.bin"), True + + +def _copy_file_to_stdout(path: str) -> None: + sys.stdout.flush() + with open(path, "rb") as f: + shutil.copyfileobj(f, sys.stdout.buffer) + sys.stdout.buffer.flush() + + +def _load_shots( + circuit: stim.Circuit, + args: 
argparse.Namespace, + *, + temp_dir: str, +) -> list[DecodingShot]: + if args.in_file: + in_path = _resolve_stdin_path_if_needed(args.in_file, temp_dir=temp_dir, stem="shots_in") + appended_obs_count = circuit.num_observables if args.in_includes_appended_observables else 0 + dets, obs = _read_detector_shot_arrays( + path=in_path, + fmt=args.in_format, + num_detectors=circuit.num_detectors, + num_observables=appended_obs_count, + ) + + if args.obs_in_file: + obs_in_path = _resolve_stdin_path_if_needed(args.obs_in_file, temp_dir=temp_dir, stem="obs_in") + obs = _read_observable_shot_array( + path=obs_in_path, + fmt=args.obs_in_format, + num_observables=circuit.num_observables, + ) + if len(obs) != len(dets): + raise ValueError("Observable input ended before, or after, the detector shot data.") + else: + dets, obs = _sample_shot_arrays(circuit, shots=args.sample_num_shots, seed=args.sample_seed) + + dets, obs = _apply_shot_range( + dets, + obs, + shot_range_begin=args.shot_range_begin, + shot_range_end=args.shot_range_end, + ) + return _shots_from_arrays(dets, obs) + + +def decode_beam_search_detcost_ranked( + model: DecoderModel, + actual_dets_mask: int, + L: int, +) -> BeamDecodeResult: + start_time = time.perf_counter() + + if (actual_dets_mask & ~model.all_possible_dets_mask) != 0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=0.0, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + beam = [(0, 1.0, 1.0)] + discarded_mass = 0.0 + + for i, fault in enumerate(model.faults): + collapsed_probs: dict[int, list[float]] = {} + total_mass = 0.0 + retiring_mask = model.retiring_masks[i] + + if retiring_mask == 0: + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + total_mass += absent_total + entry = collapsed_probs.get(state) + if entry is None: + collapsed_probs[state] = [absent_total, absent_delta] + else: + entry[0] += 
absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + total_mass += present_total + entry = collapsed_probs.get(toggled) + if entry is None: + collapsed_probs[toggled] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + else: + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + if (state & retiring_mask) == expected_bits: + shrunk = state & keep_mask + total_mass += absent_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + if (toggled & retiring_mask) == expected_bits: + shrunk = toggled & keep_mask + total_mass += present_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + + if total_mass == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=discarded_mass, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + ranked_states: list[tuple[float, float, int, float]] = [] + live_target_mask = actual_dets_mask & model.live_masks_after[i + 1] + next_future_detcost = model.future_detcost[i + 1] + for state, (total, delta) in collapsed_probs.items(): + mismatch_mask = state ^ live_target_mask + penalty = _detcost_penalty(mismatch_mask=mismatch_mask, future_detcost=next_future_detcost) + if penalty == math.inf: + rank_score = -math.inf + else: + rank_score = math.log(total) - penalty + ranked_states.append((rank_score, total, 
state, delta)) + + dropped_mass = 0.0 + if len(ranked_states) > L: + ranked_states.sort(reverse=True) + kept = ranked_states[:L] + beam = [(state, total, delta) for _, total, state, delta in kept] + kept_mass = sum(total for _, total, _, _ in kept) + dropped_mass = total_mass - kept_mass + else: + beam = [(state, total, delta) for _, total, state, delta in ranked_states] + + inv_total_mass = 1.0 / total_mass + discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass + beam = [ + (state, total * inv_total_mass, delta * inv_total_mass) + for state, total, delta in beam + ] + + _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + margin = abs(final_delta) + certified = margin > discarded_mass + + if final_delta == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=margin, + discarded_mass=discarded_mass, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + return BeamDecodeResult( + predicted_logical=final_delta < 0.0, + certified=certified, + margin=margin, + discarded_mass=discarded_mass, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + +def _print_run_header( + *, + circuit: stim.Circuit, + args: argparse.Namespace, + num_shots: int, + log_stream, +) -> None: + print(f"Running on circuit {args.circuit}", file=log_stream) + print(f"Total Detectors: {circuit.num_detectors}", file=log_stream) + print(f"Total Observables: {circuit.num_observables}", file=log_stream) + if args.in_file: + print(f"Shot Input: {args.in_file}", file=log_stream) + print(f"Shot Input Format: {args.in_format}", file=log_stream) + if args.in_includes_appended_observables: + print("Observable Input: appended to --in", file=log_stream) + elif args.obs_in_file: + print(f"Observable Input: {args.obs_in_file}", file=log_stream) + print(f"Observable Format: {args.obs_in_format}", file=log_stream) + else: + print("Observable Input: none", 
file=log_stream) + else: + print(f"Sample Seed: {args.sample_seed}", file=log_stream) + print(f"Requested Shots: {args.sample_num_shots}", file=log_stream) + if args.shot_range_begin or args.shot_range_end: + print( + f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})", + file=log_stream, + ) + print(f"Num Shots: {num_shots}", file=log_stream) + + +def run_experiment(args: argparse.Namespace) -> ExperimentSummary: + circuit = stim.Circuit.from_file(args.circuit) + if circuit.num_observables != 1: + raise ValueError( + "This decoder currently supports exactly one logical observable, because it only tracks L0. " + f"The circuit has {circuit.num_observables} observables." + ) + + model = _build_decoder_model(circuit) + log_stream = sys.stderr if args.out_file == "-" else sys.stdout + + with tempfile.TemporaryDirectory() as temp_dir: + shots = _load_shots(circuit, args, temp_dir=temp_dir) + _print_run_header(circuit=circuit, args=args, num_shots=len(shots), log_stream=log_stream) + + num_errors = 0 + num_low_confidence = 0 + num_certified = 0 + num_truth_shots = 0 + num_scored_shots = 0 + total_elapsed = 0.0 + total_triggered = 0 + max_width_seen = 0 + predictions: list[bool | None] = [] + + for shot_index, shot in enumerate(shots): + result = decode_beam_search_detcost_ranked(model, shot.det_mask, args.beam) + predictions.append(result.predicted_logical) + + success: bool | None + if shot.actual_logical is None or result.predicted_logical is None: + success = None + else: + success = result.predicted_logical == shot.actual_logical + + if result.predicted_logical is None: + num_low_confidence += 1 + if shot.actual_logical is not None: + num_truth_shots += 1 + if success is not None: + num_scored_shots += 1 + if not success: + num_errors += 1 + if result.certified: + num_certified += 1 + + total_elapsed += result.elapsed_seconds + triggered_dets = shot.det_mask.bit_count() + total_triggered += triggered_dets + max_width_seen = max(max_width_seen, 
result.max_width) + + shots_done = shot_index + 1 + error_rate_so_far = num_errors / num_scored_shots if num_scored_shots else 0.0 + print( + f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} " + f"low_conf_so_far={num_low_confidence} scored_shots_so_far={num_scored_shots} " + f"error_rate_so_far={error_rate_so_far:.6f} elapsed_total_seconds={total_elapsed:.6f}", + file=log_stream, + ) + + if args.print_per_shot: + print( + f"shot={shot_index} triggered_detectors={triggered_dets} " + f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} " + f"success={success} certified={result.certified} " + f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} " + f"elapsed_seconds={result.elapsed_seconds:.6f}", + file=log_stream, + ) + + if args.out_file: + output_path, copy_to_stdout = _resolve_stdout_path_if_needed( + args.out_file, + temp_dir=temp_dir, + stem="predictions_out", + ) + prediction_data = np.zeros((len(predictions), circuit.num_observables), dtype=np.bool_) + for shot_index, predicted_logical in enumerate(predictions): + prediction_data[shot_index, 0] = bool(predicted_logical) if predicted_logical is not None else False + + if args.out_format == "ptb64" and len(prediction_data) % 64 != 0: + raise ValueError("The ptb64 format requires the number of shots to be a multiple of 64.") + + stim.write_shot_data_file( + data=prediction_data, + path=output_path, + format=args.out_format, + num_measurements=0, + num_detectors=0, + num_observables=circuit.num_observables, + ) + if copy_to_stdout: + _copy_file_to_stdout(output_path) + if num_low_confidence: + print( + f"warning: wrote {num_low_confidence} low-confidence predictions as L0=0 because Stim result " + "files can only store bits, not unknown values.", + file=log_stream, + ) + + print(f"Beam: {args.beam}", file=log_stream) + print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}", file=log_stream) + print(f"Max Width: 
{max_width_seen}", file=log_stream) + print(f"Certified Shots: {num_certified}", file=log_stream) + print(f"Low Confidence: {num_low_confidence}", file=log_stream) + print(f"Truth-Labeled Shots: {num_truth_shots}", file=log_stream) + print(f"Scored Shots: {num_scored_shots}", file=log_stream) + if num_truth_shots: + print(f"Logical Errors: {num_errors}", file=log_stream) + else: + print("Logical Errors: n/a", file=log_stream) + print(f"Total Seconds: {total_elapsed:.6f}", file=log_stream) + print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}", file=log_stream) + + return ExperimentSummary( + predictions=predictions, + num_certified=num_certified, + num_low_confidence=num_low_confidence, + num_errors=num_errors, + num_truth_shots=num_truth_shots, + num_scored_shots=num_scored_shots, + total_elapsed=total_elapsed, + total_triggered=total_triggered, + max_width_seen=max_width_seen, + ) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Run trellis beam decoding ranked by mass minus a detcost-style future penalty, " + "with Stim-compatible shot-data I/O options." + ), + allow_abbrev=False, + ) + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") + parser.add_argument( + "--sample-num-shots", + type=int, + default=None, + help="Number of sampled shots. Defaults to 1 unless --in is provided.", + ) + parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.") + parser.add_argument( + "--shot-range-begin", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." + ), + ) + parser.add_argument( + "--shot-range-end", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. 
" + "Otherwise only decode shots in [begin, end)." + ), + ) + parser.add_argument( + "--in", + dest="in_file", + default="", + help="File to read detection events from (use - for stdin).", + ) + parser.add_argument( + "--in-format", + "--in_format", + dest="in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--in-includes-appended-observables", + "--in_includes_appended_observables", + dest="in_includes_appended_observables", + action="store_true", + help="Assume the observable flips are appended to each shot in --in.", + ) + parser.add_argument( + "--obs-in", + "--obs_in", + dest="obs_in_file", + default="", + help="File to read observable flips from (use - for stdin).", + ) + parser.add_argument( + "--obs-in-format", + "--obs_in_format", + dest="obs_in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --obs-in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--out", + dest="out_file", + default="", + help="File to write predicted observable flips to (use - for stdout).", + ) + parser.add_argument( + "--out-format", + "--out_format", + dest="out_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file written by --out ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--print-per-shot", + action="store_true", + help="Print a detailed line per decoded shot.", + ) + args = parser.parse_args() + + if args.sample_num_shots is None: + # Preserve the original script's one-shot default while still allowing + # file input without requiring --sample-num-shots 0. 
+ args.sample_num_shots = 0 if args.in_file else 1 + + if args.beam <= 0: + raise ValueError("--beam must be positive.") + if args.sample_num_shots < 0: + raise ValueError("--sample-num-shots must be non-negative.") + if args.sample_seed is not None and args.sample_seed < 0: + raise ValueError("--sample-seed must be non-negative.") + if args.shot_range_begin < 0 or args.shot_range_end < 0: + raise ValueError("--shot-range-begin and --shot-range-end must be non-negative.") + if args.shot_range_end < args.shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if args.in_includes_appended_observables and args.obs_in_file: + raise ValueError( + "Choose either --in-includes-appended-observables or --obs-in, not both." + ) + if args.obs_in_file and not args.in_file: + raise ValueError("Cannot load observable flips from --obs-in without also providing --in.") + if args.in_file == "-" and args.obs_in_file == "-": + raise ValueError("At most one of --in and --obs-in may read from stdin.") + + num_shot_sources = int(args.sample_num_shots > 0) + int(bool(args.in_file)) + if num_shot_sources != 1: + raise ValueError("Requires exactly one source of shots: either --sample-num-shots > 0 or --in.") + + return args + + +if __name__ == "__main__": + run_experiment(_parse_args()) diff --git a/src/py/astar/trellis_beam_detcost_ranked_threshold.py b/src/py/astar/trellis_beam_detcost_ranked_threshold.py new file mode 100644 index 0000000..9e2afc4 --- /dev/null +++ b/src/py/astar/trellis_beam_detcost_ranked_threshold.py @@ -0,0 +1,1100 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import math +import shutil +import sys +import tempfile +import time +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np +import stim + + +STIM_RESULT_FORMATS = ("01", "b8", "r8", "ptb64", "hits", "dets") +STIM_RESULT_FORMATS_HELP = "/".join(STIM_RESULT_FORMATS) 
+BEAM_PRUNE_MODES = ("fixed", "delta", "mass") +BEAM_PRUNE_MODES_HELP = "/".join(BEAM_PRUNE_MODES) + + +@dataclass(frozen=True) +class Fault: + q: float + p: float + delta_scale: float + det_mask: int + likelihood_cost: float + + +@dataclass(frozen=True) +class DecoderModel: + faults: tuple[Fault, ...] + retiring_masks: tuple[int, ...] + live_masks_after: tuple[int, ...] + future_detcost: tuple[tuple[float, ...], ...] + all_possible_dets_mask: int + max_width: int + + +@dataclass(frozen=True) +class BeamPruningConfig: + mode: str + beam_width: int + score_delta: float | None + mass_epsilon: float | None + hard_cap: int | None + + +@dataclass(frozen=True) +class BeamDecodeResult: + predicted_logical: bool | None + certified: bool + margin: float + discarded_mass: float + kept_state_counts: tuple[int, ...] + max_width: int + elapsed_seconds: float + + +@dataclass(frozen=True) +class DecodingShot: + det_mask: int + actual_logical: bool | None + + +@dataclass(frozen=True) +class IntegerSeriesSummary: + count: int + minimum: int | None + median: float | None + mean: float | None + maximum: int | None + + +@dataclass +class IntegerHistogramAccumulator: + count: int = 0 + total: int = 0 + minimum: int | None = None + maximum: int | None = None + histogram: dict[int, int] = field(default_factory=dict) + + def add(self, value: int) -> None: + self.count += 1 + self.total += value + if self.minimum is None or value < self.minimum: + self.minimum = value + if self.maximum is None or value > self.maximum: + self.maximum = value + self.histogram[value] = self.histogram.get(value, 0) + 1 + + def add_many(self, values: tuple[int, ...] 
| list[int]) -> None: + for value in values: + self.add(value) + + def summary(self) -> IntegerSeriesSummary: + if self.count == 0: + return IntegerSeriesSummary( + count=0, + minimum=None, + median=None, + mean=None, + maximum=None, + ) + + lower_target = (self.count - 1) // 2 + upper_target = self.count // 2 + seen = 0 + lower_value: int | None = None + upper_value: int | None = None + for value in sorted(self.histogram): + seen += self.histogram[value] + if lower_value is None and seen > lower_target: + lower_value = value + if upper_value is None and seen > upper_target: + upper_value = value + break + + assert lower_value is not None and upper_value is not None + return IntegerSeriesSummary( + count=self.count, + minimum=self.minimum, + median=(lower_value + upper_value) / 2.0, + mean=self.total / self.count, + maximum=self.maximum, + ) + + +@dataclass(frozen=True) +class ExperimentSummary: + predictions: list[bool | None] + num_certified: int + num_low_confidence: int + num_errors: int + num_truth_shots: int + num_scored_shots: int + total_elapsed: float + total_triggered: int + max_width_seen: int + kept_state_summary: IntegerSeriesSummary + + +def _likelihood_cost(probability: float) -> float: + if probability <= 0.0: + return math.inf + if probability >= 1.0: + return 0.0 + return -math.log(probability / (1.0 - probability)) + + +def _detectors_from_mask(mask: int) -> list[int]: + detectors: list[int] = [] + while mask: + low_bit = mask & -mask + detectors.append(low_bit.bit_length() - 1) + mask ^= low_bit + return detectors + + +def _mask_from_bool_row(row: np.ndarray) -> int: + mask = 0 + for index in np.flatnonzero(row): + mask |= 1 << int(index) + return mask + + +def _future_detcost_by_detector(faults: tuple[Fault, ...], num_detectors: int) -> tuple[tuple[float, ...], ...]: + future_detcost: list[list[float]] = [[math.inf] * num_detectors for _ in range(len(faults) + 1)] + next_row = future_detcost[-1] + for fault_index in range(len(faults) - 1, -1, 
-1): + row = next_row.copy() + fault = faults[fault_index] + det_count = fault.det_mask.bit_count() + if det_count: + ecost = fault.likelihood_cost / det_count + for det_id in _detectors_from_mask(fault.det_mask): + if ecost < row[det_id]: + row[det_id] = ecost + future_detcost[fault_index] = row + next_row = row + return tuple(tuple(row) for row in future_detcost) + + +def _build_decoder_model(circuit: stim.Circuit) -> DecoderModel: + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + faults: list[Fault] = [] + all_possible_dets_mask = 0 + last_seen_index: dict[int, int] = {} + + for inst in dem: + if inst.type != "error": + continue + + p = float(inst.args_copy()[0]) + det_mask = 0 + flip_l0 = 0 + for target in inst.targets_copy(): + if target.is_separator(): + continue + if target.is_relative_detector_id(): + det_mask ^= 1 << target.val + elif target.is_logical_observable_id() and target.val == 0: + flip_l0 ^= 1 + + faults.append( + Fault( + q=1.0 - p, + p=p, + delta_scale=(-p if flip_l0 else p), + det_mask=det_mask, + likelihood_cost=_likelihood_cost(p), + ) + ) + all_possible_dets_mask |= det_mask + + for det_id in _detectors_from_mask(det_mask): + last_seen_index[det_id] = len(faults) - 1 + + retiring_masks = [0] * len(faults) + for det_id, index in last_seen_index.items(): + retiring_masks[index] |= 1 << det_id + + live_masks_after = [0] * (len(faults) + 1) + active_mask = 0 + max_width = 0 + for i, fault in enumerate(faults): + active_mask |= fault.det_mask + max_width = max(max_width, active_mask.bit_count()) + active_mask &= ~retiring_masks[i] + live_masks_after[i + 1] = active_mask + + frozen_faults = tuple(faults) + return DecoderModel( + faults=frozen_faults, + retiring_masks=tuple(retiring_masks), + live_masks_after=tuple(live_masks_after), + future_detcost=_future_detcost_by_detector(frozen_faults, circuit.num_detectors), + all_possible_dets_mask=all_possible_dets_mask, + max_width=max_width, + ) + + +def 
_detcost_penalty(mismatch_mask: int, future_detcost: tuple[float, ...]) -> float: + total = 0.0 + pending = mismatch_mask + + while pending: + low_bit = pending & -pending + detector = low_bit.bit_length() - 1 + pending ^= low_bit + + best = future_detcost[detector] + if best == math.inf: + return math.inf + total += best + + return total + + +def _accumulate_collapsed_state( + collapsed_probs: dict[int, list[float]], + *, + state: int, + total: float, + delta: float, +) -> float: + if total <= 0.0: + return 0.0 + + entry = collapsed_probs.get(state) + if entry is None: + collapsed_probs[state] = [total, delta] + else: + entry[0] += total + entry[1] += delta + return total + + +def _prune_ranked_states( + ranked_states: list[tuple[float, float, int, float]], + *, + total_mass: float, + pruning: BeamPruningConfig, +) -> tuple[list[tuple[int, float, float]], float]: + if not ranked_states: + return [], total_mass + + ranked_states.sort(reverse=True) + + if pruning.mode == "fixed": + kept_ranked = ranked_states[:pruning.beam_width] + elif pruning.mode == "delta": + assert pruning.score_delta is not None + best_score = ranked_states[0][0] + if best_score == -math.inf: + kept_ranked = ranked_states + else: + cutoff = best_score - pruning.score_delta + kept_ranked = [entry for entry in ranked_states if entry[0] >= cutoff] + if not kept_ranked: + kept_ranked = ranked_states[:1] + if pruning.hard_cap is not None and len(kept_ranked) > pruning.hard_cap: + kept_ranked = kept_ranked[:pruning.hard_cap] + elif pruning.mode == "mass": + assert pruning.mass_epsilon is not None + retained_target_mass = (1.0 - pruning.mass_epsilon) * total_mass + retained_mass = 0.0 + kept_ranked = [] + for entry in ranked_states: + kept_ranked.append(entry) + retained_mass += entry[1] + if retained_mass >= retained_target_mass: + break + if not kept_ranked: + kept_ranked = ranked_states[:1] + if pruning.hard_cap is not None and len(kept_ranked) > pruning.hard_cap: + kept_ranked = 
kept_ranked[:pruning.hard_cap] + else: + raise ValueError(f"Unsupported pruning mode: {pruning.mode!r}") + + kept_mass = sum(total for _, total, _, _ in kept_ranked) + dropped_mass = total_mass - kept_mass + kept_beam = [(state, total, delta) for _, total, state, delta in kept_ranked] + return kept_beam, dropped_mass + + +def _summarize_int_values(values: tuple[int, ...] | list[int]) -> IntegerSeriesSummary: + if not values: + return IntegerSeriesSummary( + count=0, + minimum=None, + median=None, + mean=None, + maximum=None, + ) + + sorted_values = sorted(values) + count = len(sorted_values) + lower = sorted_values[(count - 1) // 2] + upper = sorted_values[count // 2] + return IntegerSeriesSummary( + count=count, + minimum=sorted_values[0], + median=(lower + upper) / 2.0, + mean=sum(sorted_values) / count, + maximum=sorted_values[-1], + ) + + +def _format_optional_int(value: int | None) -> str: + return "none" if value is None else str(value) + + +def _format_pruning_value(value: float | None) -> str: + if value is None: + return "n/a" + return f"{value:.6g}" + + +def _format_summary_int(value: int | None) -> str: + return "n/a" if value is None else str(value) + + +def _format_summary_float(value: float | None, *, digits: int = 2) -> str: + return "n/a" if value is None else f"{value:.{digits}f}" + + +def _print_pruning_configuration(*, pruning: BeamPruningConfig, log_stream) -> None: + print(f"Beam Prune Mode: {pruning.mode}", file=log_stream) + if pruning.mode == "fixed": + print(f"Beam Width: {pruning.beam_width}", file=log_stream) + elif pruning.mode == "delta": + print(f"Beam Score Delta: {_format_pruning_value(pruning.score_delta)}", file=log_stream) + print(f"Beam Hard Cap: {_format_optional_int(pruning.hard_cap)}", file=log_stream) + elif pruning.mode == "mass": + assert pruning.mass_epsilon is not None + print(f"Beam Mass Epsilon: {_format_pruning_value(pruning.mass_epsilon)}", file=log_stream) + print(f"Beam Retained Mass: {_format_pruning_value(1.0 - 
pruning.mass_epsilon)}", file=log_stream) + print(f"Beam Hard Cap: {_format_optional_int(pruning.hard_cap)}", file=log_stream) + else: + raise ValueError(f"Unsupported pruning mode: {pruning.mode!r}") + + +def _beam_pruning_config_from_args(args: argparse.Namespace) -> BeamPruningConfig: + return BeamPruningConfig( + mode=args.beam_prune_mode, + beam_width=args.beam, + score_delta=args.beam_score_delta, + mass_epsilon=args.beam_mass_epsilon, + hard_cap=args.beam_hard_cap, + ) + + +def _as_bool_2d(data: np.ndarray, *, expected_cols: int, description: str) -> np.ndarray: + arr = np.asarray(data) + if arr.ndim != 2: + raise ValueError(f"Expected {description} to be a 2D array but got shape {arr.shape!r}.") + if arr.shape[1] != expected_cols: + raise ValueError( + f"Expected {description} to have {expected_cols} columns but got {arr.shape[1]}." + ) + if arr.dtype != np.bool_: + arr = arr.astype(np.bool_, copy=False) + return arr + + +def _sample_shot_arrays( + circuit: stim.Circuit, + *, + shots: int, + seed: int | None, +) -> tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets, obs = sampler.sample(shots=shots, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=circuit.num_detectors, description="sampled detector data"), + _as_bool_2d(obs, expected_cols=circuit.num_observables, description="sampled observable data"), + ) + + +def _read_detector_shot_arrays( + *, + path: str, + fmt: str, + num_detectors: int, + num_observables: int, +) -> tuple[np.ndarray, np.ndarray | None]: + common_kwargs = dict( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=num_detectors, + num_observables=num_observables, + ) + + if num_observables: + try: + dets, obs = stim.read_shot_data_file(**common_kwargs, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=num_detectors, description="input detector data"), + _as_bool_2d(obs, expected_cols=num_observables, 
description="appended observable data"), + ) + except TypeError: + flat = stim.read_shot_data_file(**common_kwargs) + flat = _as_bool_2d( + flat, + expected_cols=num_detectors + num_observables, + description="combined detector/observable input data", + ) + return flat[:, :num_detectors], flat[:, num_detectors:] + + flat = stim.read_shot_data_file(**common_kwargs) + return _as_bool_2d(flat, expected_cols=num_detectors, description="input detector data"), None + + +def _read_observable_shot_array(*, path: str, fmt: str, num_observables: int) -> np.ndarray: + obs = stim.read_shot_data_file( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=0, + num_observables=num_observables, + ) + return _as_bool_2d(obs, expected_cols=num_observables, description="observable input data") + + +def _apply_shot_range( + dets: np.ndarray, + obs: np.ndarray | None, + *, + shot_range_begin: int, + shot_range_end: int, +) -> tuple[np.ndarray, np.ndarray | None]: + if not (shot_range_begin or shot_range_end): + return dets, obs + + if shot_range_end < shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if shot_range_end > len(dets): + raise ValueError( + f"Shot range end {shot_range_end} is past the end of the shot data (size {len(dets)})." 
+ ) + + dets = dets[shot_range_begin:shot_range_end] + if obs is not None: + obs = obs[shot_range_begin:shot_range_end] + return dets, obs + + +def _shots_from_arrays(dets: np.ndarray, obs: np.ndarray | None) -> list[DecodingShot]: + shots: list[DecodingShot] = [] + for shot_index in range(dets.shape[0]): + actual_logical = None if obs is None else bool(obs[shot_index, 0]) + shots.append( + DecodingShot( + det_mask=_mask_from_bool_row(dets[shot_index]), + actual_logical=actual_logical, + ) + ) + return shots + + +def _resolve_stdin_path_if_needed(path: str, *, temp_dir: str, stem: str) -> str: + if path != "-": + return path + temp_path = str(Path(temp_dir) / f"{stem}.bin") + with open(temp_path, "wb") as f: + f.write(sys.stdin.buffer.read()) + return temp_path + + +def _resolve_stdout_path_if_needed(path: str, *, temp_dir: str, stem: str) -> tuple[str, bool]: + if path != "-": + return path, False + return str(Path(temp_dir) / f"{stem}.bin"), True + + +def _copy_file_to_stdout(path: str) -> None: + sys.stdout.flush() + with open(path, "rb") as f: + shutil.copyfileobj(f, sys.stdout.buffer) + sys.stdout.buffer.flush() + + +def _load_shots( + circuit: stim.Circuit, + args: argparse.Namespace, + *, + temp_dir: str, +) -> list[DecodingShot]: + if args.in_file: + in_path = _resolve_stdin_path_if_needed(args.in_file, temp_dir=temp_dir, stem="shots_in") + appended_obs_count = circuit.num_observables if args.in_includes_appended_observables else 0 + dets, obs = _read_detector_shot_arrays( + path=in_path, + fmt=args.in_format, + num_detectors=circuit.num_detectors, + num_observables=appended_obs_count, + ) + + if args.obs_in_file: + obs_in_path = _resolve_stdin_path_if_needed(args.obs_in_file, temp_dir=temp_dir, stem="obs_in") + obs = _read_observable_shot_array( + path=obs_in_path, + fmt=args.obs_in_format, + num_observables=circuit.num_observables, + ) + if len(obs) != len(dets): + raise ValueError("Observable input ended before, or after, the detector shot data.") + 
    else:
        # No --in file: draw fresh samples from the circuit's own sampler.
        dets, obs = _sample_shot_arrays(circuit, shots=args.sample_num_shots, seed=args.sample_seed)

    dets, obs = _apply_shot_range(
        dets,
        obs,
        shot_range_begin=args.shot_range_begin,
        shot_range_end=args.shot_range_end,
    )
    return _shots_from_arrays(dets, obs)


def decode_beam_search_detcost_ranked(
    model: DecoderModel,
    actual_dets_mask: int,
    pruning: BeamPruningConfig,
) -> BeamDecodeResult:
    """Decode one shot with a single forward beam pass ranked by log-mass minus detcost.

    Beam entries are (detector_state_mask, total_mass, signed_delta) triples;
    ``signed_delta`` tracks the mass difference between the L0=0 and L0=1
    hypotheses.  Returns a BeamDecodeResult; predicted_logical is None when the
    decode fails or the final delta is exactly zero.
    """
    start_time = time.perf_counter()
    retained_state_counts: list[int] = []

    # A detector fired that no fault can produce: the syndrome is impossible
    # under this model, so fail immediately.
    if (actual_dets_mask & ~model.all_possible_dets_mask) != 0:
        return BeamDecodeResult(
            predicted_logical=None,
            certified=False,
            margin=0.0,
            discarded_mass=0.0,
            kept_state_counts=(0,) * len(model.faults),
            max_width=model.max_width,
            elapsed_seconds=time.perf_counter() - start_time,
        )

    # Start from the empty state with all mass (and all delta) on it.
    beam = [(0, 1.0, 1.0)]
    discarded_mass = 0.0

    for i, fault in enumerate(model.faults):
        # collapsed_probs maps state mask -> [total_mass, signed_delta].
        collapsed_probs: dict[int, list[float]] = {}
        total_mass = 0.0
        retiring_mask = model.retiring_masks[i]

        if retiring_mask == 0:
            # No detector retires at this step: branch every state on the
            # fault being absent (weight q) or present (weight p).
            for state, total, delta in beam:
                total_mass += _accumulate_collapsed_state(
                    collapsed_probs,
                    state=state,
                    total=total * fault.q,
                    delta=delta * fault.q,
                )

                total_mass += _accumulate_collapsed_state(
                    collapsed_probs,
                    state=state ^ fault.det_mask,
                    total=total * fault.p,
                    delta=delta * fault.delta_scale,
                )
        else:
            # Some detectors see their last fault here: only keep branches
            # whose retiring bits match the observed syndrome, then drop
            # those bits from the state.
            expected_bits = actual_dets_mask & retiring_mask
            keep_mask = ~retiring_mask
            for state, total, delta in beam:
                absent_total = total * fault.q
                if absent_total > 0.0 and (state & retiring_mask) == expected_bits:
                    total_mass += _accumulate_collapsed_state(
                        collapsed_probs,
                        state=state & keep_mask,
                        total=absent_total,
                        delta=delta * fault.q,
                    )

                toggled = state ^ fault.det_mask
                present_total = total * fault.p
                if present_total > 0.0 and (toggled & retiring_mask) == expected_bits:
                    total_mass += _accumulate_collapsed_state(
                        collapsed_probs,
                        state=toggled & keep_mask,
                        total=present_total,
                        delta=delta * fault.delta_scale,
                    )

        if total_mass == 0.0:
            # Every branch was filtered out: no explanation survives.
            retained_state_counts.append(0)
            retained_state_counts.extend([0] * (len(model.faults) - i - 1))
            return BeamDecodeResult(
                predicted_logical=None,
                certified=False,
                margin=0.0,
                discarded_mass=discarded_mass,
                kept_state_counts=tuple(retained_state_counts),
                max_width=model.max_width,
                elapsed_seconds=time.perf_counter() - start_time,
            )

        # Rank surviving states by log(mass) minus a detcost-style lower bound
        # on the future cost of fixing their remaining syndrome mismatch.
        ranked_states: list[tuple[float, float, int, float]] = []
        live_target_mask = actual_dets_mask & model.live_masks_after[i + 1]
        next_future_detcost = model.future_detcost[i + 1]
        for state, (total, delta) in collapsed_probs.items():
            if total <= 0.0:
                continue
            mismatch_mask = state ^ live_target_mask
            penalty = _detcost_penalty(mismatch_mask=mismatch_mask, future_detcost=next_future_detcost)
            if penalty == math.inf:
                # Unfixable mismatch: rank it last instead of dropping it here.
                rank_score = -math.inf
            else:
                rank_score = math.log(total) - penalty
            ranked_states.append((rank_score, total, state, delta))

        beam, dropped_mass = _prune_ranked_states(
            ranked_states,
            total_mass=total_mass,
            pruning=pruning,
        )
        retained_state_counts.append(len(beam))

        # Renormalize the beam so kept masses sum to at most 1, and carry the
        # cumulative discarded mass in the same normalized units.
        inv_total_mass = 1.0 / total_mass
        discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass
        beam = [
            (state, total * inv_total_mass, delta * inv_total_mass)
            for state, total, delta in beam
        ]

    # After the final step only state 0 (fully explained syndrome) matters.
    _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0))
    margin = abs(final_delta)
    # Certified: even if all discarded mass voted the other way, the sign of
    # final_delta could not flip.
    certified = margin > discarded_mass

    if final_delta == 0.0:
        return BeamDecodeResult(
            predicted_logical=None,
            certified=False,
            margin=margin,
            discarded_mass=discarded_mass,
            kept_state_counts=tuple(retained_state_counts),
            max_width=model.max_width,
            elapsed_seconds=time.perf_counter() - start_time,
        )
    return BeamDecodeResult(
        predicted_logical=final_delta < 0.0,
        certified=certified,
        margin=margin,
        discarded_mass=discarded_mass,
        kept_state_counts=tuple(retained_state_counts),
        max_width=model.max_width,
        elapsed_seconds=time.perf_counter() - start_time,
    )


def _print_run_header(
    *,
    circuit: stim.Circuit,
    args: argparse.Namespace,
    pruning: BeamPruningConfig,
    num_faults: int,
    num_shots: int,
    log_stream,
) -> None:
    """Print the run configuration (circuit stats, pruning, I/O sources) to log_stream."""
    print(f"Running on circuit {args.circuit}", file=log_stream)
    print(f"Total Detectors: {circuit.num_detectors}", file=log_stream)
    print(f"Total Observables: {circuit.num_observables}", file=log_stream)
    print(f"Total Faults: {num_faults}", file=log_stream)
    _print_pruning_configuration(pruning=pruning, log_stream=log_stream)
    if args.in_file:
        print(f"Shot Input: {args.in_file}", file=log_stream)
        print(f"Shot Input Format: {args.in_format}", file=log_stream)
        if args.in_includes_appended_observables:
            print("Observable Input: appended to --in", file=log_stream)
        elif args.obs_in_file:
            print(f"Observable Input: {args.obs_in_file}", file=log_stream)
            print(f"Observable Format: {args.obs_in_format}", file=log_stream)
        else:
            print("Observable Input: none", file=log_stream)
    else:
        print(f"Sample Seed: {args.sample_seed}", file=log_stream)
        print(f"Requested Shots: {args.sample_num_shots}", file=log_stream)
    if args.shot_range_begin or args.shot_range_end:
        print(
            f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})",
            file=log_stream,
        )
    print(f"Num Shots: {num_shots}", file=log_stream)


def run_experiment(args: argparse.Namespace) -> ExperimentSummary:
    """Decode every requested shot, log progress/summary stats, and optionally write predictions."""
    circuit = stim.Circuit.from_file(args.circuit)
    if circuit.num_observables != 1:
        raise ValueError(
            "This decoder currently supports exactly one logical observable, because it only tracks L0. "
            f"The circuit has {circuit.num_observables} observables."
        )

    model = _build_decoder_model(circuit)
    pruning = _beam_pruning_config_from_args(args)
    # When predictions go to stdout, keep human-readable logging on stderr.
    log_stream = sys.stderr if args.out_file == "-" else sys.stdout

    with tempfile.TemporaryDirectory() as temp_dir:
        shots = _load_shots(circuit, args, temp_dir=temp_dir)
        _print_run_header(
            circuit=circuit,
            args=args,
            pruning=pruning,
            num_faults=len(model.faults),
            num_shots=len(shots),
            log_stream=log_stream,
        )

        # Running tallies across the whole shot set.
        num_errors = 0
        num_low_confidence = 0
        num_certified = 0
        num_truth_shots = 0
        num_scored_shots = 0
        total_elapsed = 0.0
        total_triggered = 0
        max_width_seen = 0
        predictions: list[bool | None] = []
        kept_state_accumulator = IntegerHistogramAccumulator()

        for shot_index, shot in enumerate(shots):
            result = decode_beam_search_detcost_ranked(model, shot.det_mask, pruning)
            predictions.append(result.predicted_logical)
            kept_state_accumulator.add_many(result.kept_state_counts)
            kept_state_summary = _summarize_int_values(result.kept_state_counts)

            # success is None when either the prediction or the truth label is missing.
            success: bool | None
            if shot.actual_logical is None or result.predicted_logical is None:
                success = None
            else:
                success = result.predicted_logical == shot.actual_logical

            if result.predicted_logical is None:
                num_low_confidence += 1
            if shot.actual_logical is not None:
                num_truth_shots += 1
            if success is not None:
                num_scored_shots += 1
                if not success:
                    num_errors += 1
            if result.certified:
                num_certified += 1

            total_elapsed += result.elapsed_seconds
            triggered_dets = shot.det_mask.bit_count()
            total_triggered += triggered_dets
            max_width_seen = max(max_width_seen, result.max_width)

            # One machine-greppable progress line per shot.
            shots_done = shot_index + 1
            error_rate_so_far = num_errors / num_scored_shots if num_scored_shots else 0.0
            print(
                f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} "
                f"low_conf_so_far={num_low_confidence} scored_shots_so_far={num_scored_shots} "
                f"error_rate_so_far={error_rate_so_far:.6f} elapsed_total_seconds={total_elapsed:.6f} "
                f"kept_states_min={_format_summary_int(kept_state_summary.minimum)} "
                f"kept_states_median={_format_summary_float(kept_state_summary.median)} "
                f"kept_states_mean={_format_summary_float(kept_state_summary.mean)} "
                f"kept_states_max={_format_summary_int(kept_state_summary.maximum)}",
                file=log_stream,
            )

            if args.print_per_shot:
                print(
                    f"shot={shot_index} triggered_detectors={triggered_dets} "
                    f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} "
                    f"success={success} certified={result.certified} "
                    f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} "
                    f"kept_states_min={_format_summary_int(kept_state_summary.minimum)} "
                    f"kept_states_median={_format_summary_float(kept_state_summary.median)} "
                    f"kept_states_mean={_format_summary_float(kept_state_summary.mean)} "
                    f"kept_states_max={_format_summary_int(kept_state_summary.maximum)} "
                    f"elapsed_seconds={result.elapsed_seconds:.6f}",
                    file=log_stream,
                )

        if args.out_file:
            output_path, copy_to_stdout = _resolve_stdout_path_if_needed(
                args.out_file,
                temp_dir=temp_dir,
                stem="predictions_out",
            )
            # Stim result files store bits only, so low-confidence (None)
            # predictions are written as False; a warning is emitted below.
            prediction_data = np.zeros((len(predictions), circuit.num_observables), dtype=np.bool_)
            for shot_index, predicted_logical in enumerate(predictions):
                prediction_data[shot_index, 0] = bool(predicted_logical) if predicted_logical is not None else False

            if args.out_format == "ptb64" and len(prediction_data) % 64 != 0:
                raise ValueError("The ptb64 format requires the number of shots to be a multiple of 64.")

            stim.write_shot_data_file(
                data=prediction_data,
                path=output_path,
                format=args.out_format,
                num_measurements=0,
                num_detectors=0,
                num_observables=circuit.num_observables,
            )
            if copy_to_stdout:
                _copy_file_to_stdout(output_path)
            if num_low_confidence:
                print(
                    f"warning: wrote {num_low_confidence} low-confidence predictions as L0=0 because Stim result "
                    "files can only store bits, not unknown values.",
                    file=log_stream,
                )

    # Aggregate kept-state counts over every shot for the final report.
    kept_state_summary = kept_state_accumulator.summary()

    print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}", file=log_stream)
    print(f"Max Width: {max_width_seen}", file=log_stream)
    print(f"{'Kept States/Fault Min:':<26}{_format_summary_int(kept_state_summary.minimum)}", file=log_stream)
    print(f"{'Kept States/Fault Median:':<26}{_format_summary_float(kept_state_summary.median)}", file=log_stream)
    print(f"{'Kept States/Fault Mean:':<26}{_format_summary_float(kept_state_summary.mean)}", file=log_stream)
    print(f"{'Kept States/Fault Max:':<26}{_format_summary_int(kept_state_summary.maximum)}", file=log_stream)
    print(f"Certified Shots: {num_certified}", file=log_stream)
    print(f"Low Confidence: {num_low_confidence}", file=log_stream)
    print(f"Truth-Labeled Shots: {num_truth_shots}", file=log_stream)
    print(f"Scored Shots: {num_scored_shots}", file=log_stream)
    if num_truth_shots:
        print(f"Logical Errors: {num_errors}", file=log_stream)
    else:
        print("Logical Errors: n/a", file=log_stream)
    print(f"Total Seconds: {total_elapsed:.6f}", file=log_stream)
    print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}", file=log_stream)

    return ExperimentSummary(
        predictions=predictions,
        num_certified=num_certified,
        num_low_confidence=num_low_confidence,
        num_errors=num_errors,
        num_truth_shots=num_truth_shots,
        num_scored_shots=num_scored_shots,
        total_elapsed=total_elapsed,
        total_triggered=total_triggered,
        max_width_seen=max_width_seen,
        kept_state_summary=kept_state_summary,
    )


def _parse_args() -> argparse.Namespace:
    """Parse and cross-validate the command-line arguments for this script."""
    parser = argparse.ArgumentParser(
        description=(
            "Run trellis beam decoding ranked by mass minus a detcost-style future penalty, "
            "with optional adaptive threshold pruning and Stim-compatible shot-data I/O options."
        ),
        allow_abbrev=False,
    )
    parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.")
    parser.add_argument(
        "--beam",
        type=int,
        default=1000,
        help="Beam width cutoff used when --beam-prune-mode=fixed.",
    )
    parser.add_argument(
        "--beam-prune-mode",
        choices=BEAM_PRUNE_MODES,
        default="fixed",
        help=(
            "Beam pruning rule: fixed keeps the top --beam states, delta keeps all states within "
            "--beam-score-delta of the best rank score, and mass keeps a rank-sorted prefix whose "
            "retained normalized mass reaches 1-epsilon."
        ),
    )
    parser.add_argument(
        "--beam-score-delta",
        type=float,
        default=None,
        help=(
            "For --beam-prune-mode=delta, keep every state whose rank score is within this additive "
            "gap of the best state's rank score."
        ),
    )
    parser.add_argument(
        "--beam-mass-epsilon",
        type=float,
        default=None,
        help=(
            "For --beam-prune-mode=mass, keep the smallest rank-sorted prefix whose retained "
            "normalized mass is at least 1 - epsilon."
        ),
    )
    parser.add_argument(
        "--beam-retained-mass",
        type=float,
        default=None,
        help=(
            "For --beam-prune-mode=mass, equivalent to setting --beam-mass-epsilon to "
            "1 - retained_mass."
        ),
    )
    parser.add_argument(
        "--beam-hard-cap",
        type=int,
        default=None,
        help=(
            "Optional hard cap on the number of states retained after delta or mass thresholding. "
            "Ignored in fixed mode."
        ),
    )
    parser.add_argument(
        "--sample-num-shots",
        type=int,
        default=None,
        help="Number of sampled shots. Defaults to 1 unless --in is provided.",
    )
    parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.")
    parser.add_argument(
        "--shot-range-begin",
        type=int,
        default=0,
        help=(
            "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. "
            "Otherwise only decode shots in [begin, end)."
        ),
    )
    parser.add_argument(
        "--shot-range-end",
        type=int,
        default=0,
        help=(
            "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. "
            "Otherwise only decode shots in [begin, end)."
        ),
    )
    parser.add_argument(
        "--in",
        dest="in_file",
        default="",
        help="File to read detection events from (use - for stdin).",
    )
    # The underscore spellings are accepted as aliases for Stim-CLI compatibility.
    parser.add_argument(
        "--in-format",
        "--in_format",
        dest="in_format",
        choices=STIM_RESULT_FORMATS,
        default="01",
        help=f"Format of the file read by --in ({STIM_RESULT_FORMATS_HELP}).",
    )
    parser.add_argument(
        "--in-includes-appended-observables",
        "--in_includes_appended_observables",
        dest="in_includes_appended_observables",
        action="store_true",
        help="Assume the observable flips are appended to each shot in --in.",
    )
    parser.add_argument(
        "--obs-in",
        "--obs_in",
        dest="obs_in_file",
        default="",
        help="File to read observable flips from (use - for stdin).",
    )
    parser.add_argument(
        "--obs-in-format",
        "--obs_in_format",
        dest="obs_in_format",
        choices=STIM_RESULT_FORMATS,
        default="01",
        help=f"Format of the file read by --obs-in ({STIM_RESULT_FORMATS_HELP}).",
    )
    parser.add_argument(
        "--out",
        dest="out_file",
        default="",
        help="File to write predicted observable flips to (use - for stdout).",
    )
    parser.add_argument(
        "--out-format",
        "--out_format",
        dest="out_format",
        choices=STIM_RESULT_FORMATS,
        default="01",
        help=f"Format of the file written by --out ({STIM_RESULT_FORMATS_HELP}).",
    )
    parser.add_argument(
        "--print-per-shot",
        action="store_true",
        help="Print a detailed line per decoded shot.",
    )
    args = parser.parse_args()

    if args.sample_num_shots is None:
        # Preserve the original script's one-shot default while still allowing
        # file input without requiring --sample-num-shots 0.
        args.sample_num_shots = 0 if args.in_file else 1

    # Basic numeric sanity checks on the pruning flags.
    if args.beam <= 0:
        raise ValueError("--beam must be positive.")
    if args.beam_hard_cap is not None and args.beam_hard_cap <= 0:
        raise ValueError("--beam-hard-cap must be positive when provided.")
    if args.beam_score_delta is not None:
        if math.isnan(args.beam_score_delta) or args.beam_score_delta < 0.0:
            raise ValueError("--beam-score-delta must be a non-negative number.")
    if args.beam_mass_epsilon is not None:
        if math.isnan(args.beam_mass_epsilon) or not (0.0 <= args.beam_mass_epsilon < 1.0):
            raise ValueError("--beam-mass-epsilon must satisfy 0 <= epsilon < 1.")
    if args.beam_retained_mass is not None:
        if math.isnan(args.beam_retained_mass) or not (0.0 <= args.beam_retained_mass <= 1.0):
            raise ValueError("--beam-retained-mass must satisfy 0 <= retained_mass <= 1.")
        if args.beam_mass_epsilon is not None:
            raise ValueError("Choose at most one of --beam-mass-epsilon and --beam-retained-mass.")
        # Retained-mass is just the complementary spelling of mass-epsilon.
        args.beam_mass_epsilon = 1.0 - args.beam_retained_mass
    if args.sample_num_shots < 0:
        raise ValueError("--sample-num-shots must be non-negative.")
    if args.sample_seed is not None and args.sample_seed < 0:
        raise ValueError("--sample-seed must be non-negative.")
    if args.shot_range_begin < 0 or args.shot_range_end < 0:
        raise ValueError("--shot-range-begin and --shot-range-end must be non-negative.")
    if args.shot_range_end < args.shot_range_begin:
        raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.")
    # Observable-truth inputs: appended-to---in and --obs-in are mutually exclusive.
    if args.in_includes_appended_observables and args.obs_in_file:
        raise ValueError(
            "Choose either --in-includes-appended-observables or --obs-in, not both."
        )
    if args.obs_in_file and not args.in_file:
        raise ValueError("Cannot load observable flips from --obs-in without also providing --in.")
    if args.in_file == "-" and args.obs_in_file == "-":
        raise ValueError("At most one of --in and --obs-in may read from stdin.")

    # Reject flag combinations that do not match the selected pruning mode.
    if args.beam_prune_mode == "fixed":
        if args.beam_score_delta is not None:
            raise ValueError("--beam-score-delta is only valid with --beam-prune-mode=delta.")
        if args.beam_mass_epsilon is not None:
            raise ValueError(
                "--beam-mass-epsilon/--beam-retained-mass are only valid with --beam-prune-mode=mass."
            )
        if args.beam_hard_cap is not None:
            raise ValueError("--beam-hard-cap is only meaningful with adaptive pruning modes.")
    elif args.beam_prune_mode == "delta":
        if args.beam_score_delta is None:
            raise ValueError("--beam-prune-mode=delta requires --beam-score-delta.")
        if args.beam_mass_epsilon is not None:
            raise ValueError(
                "--beam-mass-epsilon/--beam-retained-mass are not valid with --beam-prune-mode=delta."
            )
    elif args.beam_prune_mode == "mass":
        if args.beam_mass_epsilon is None:
            raise ValueError(
                "--beam-prune-mode=mass requires --beam-mass-epsilon or --beam-retained-mass."
            )
        if args.beam_score_delta is not None:
            raise ValueError("--beam-score-delta is not valid with --beam-prune-mode=mass.")
    else:
        raise ValueError(f"Unsupported --beam-prune-mode {args.beam_prune_mode!r}.")

    # Exactly one shot source: either sampling or a shot file, never both/neither.
    num_shot_sources = int(args.sample_num_shots > 0) + int(bool(args.in_file))
    if num_shot_sources != 1:
        raise ValueError("Requires exactly one source of shots: either --sample-num-shots > 0 or --in.")

    return args


if __name__ == "__main__":
    run_experiment(_parse_args())
diff --git a/src/py/astar/trellis_beam_iterative_forward_backward.py b/src/py/astar/trellis_beam_iterative_forward_backward.py
new file mode 100644
index 0000000..6217639
--- /dev/null
+++ b/src/py/astar/trellis_beam_iterative_forward_backward.py
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import math
import shutil
import sys
import tempfile
import time
from collections import Counter
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import stim


# Shot-data formats accepted by stim's read/write_shot_data_file helpers.
STIM_RESULT_FORMATS = ("01", "b8", "r8", "ptb64", "hits", "dets")
STIM_RESULT_FORMATS_HELP = "/".join(STIM_RESULT_FORMATS)


@dataclass(frozen=True)
class Fault:
    """One error mechanism from the detector error model.

    ``p`` is the fault probability and ``q`` its complement (1 - p);
    ``delta_scale`` is +/-p depending on whether the fault flips logical
    observable L0; ``det_mask`` is the bit mask of detectors it toggles.
    """
    q: float
    p: float
    delta_scale: float
    det_mask: int
    likelihood_cost: float


@dataclass(frozen=True)
class DirectionalModel:
    """Per-direction (forward or backward) trellis data derived from the fault list."""
    faults: tuple[Fault, ...]
    # retiring_masks[i]: detectors whose last touching fault (in this order) is fault i.
    retiring_masks: tuple[int, ...]
    # frontier_masks_after_step[i]: detectors still live after processing i faults.
    frontier_masks_after_step: tuple[int, ...]
    future_detcost: tuple[tuple[float, ...], ...]
    cut_after_step: tuple[int, ...]
    direction_name: str


@dataclass(frozen=True)
class FrontierSnapshot:
    """Recorded beam contents at a frontier cut, used by the opposite-direction pass."""
    local_states: tuple[int, ...]
    masses: tuple[float, ...]
    kept_total_mass: float
    discarded_mass: float
    one_masses: tuple[float, ...]


@dataclass(frozen=True)
class DecodePassSummary:
    """Per-pass diagnostics (counters and widths) for one beam sweep."""
    pass_index: int
    direction: str
    final_delta: float
    discarded_mass: float
    elapsed_seconds: float
    exact_upper_hits: int
    exact_upper_misses: int
    marginal_tighter_count: int
    opposite_selected_count: int
    detcost_selected_count: int
    mean_frontier_width: float
    max_frontier_width: int
    mean_beam_size: float
    max_beam_size: int
    opposite_available_steps: int


@dataclass(frozen=True)
class DecoderModel:
    """Precomputed circuit data shared by the forward and backward passes."""
    faults: tuple[Fault, ...]
    forward_model: DirectionalModel
    backward_model: DirectionalModel
    frontier_masks_by_cut: tuple[int, ...]
    frontier_detector_ids_by_cut: tuple[tuple[int, ...], ...]
    frontier_global_bit_masks_by_cut: tuple[tuple[int, ...], ...]
    all_possible_dets_mask: int
    max_width: int
    repeated_frontier_mask_count: int
    max_frontier_mask_repeat: int


@dataclass(frozen=True)
class BeamDecodeResult:
    """Final decode outcome; predicted_logical is None on failure or a tie."""
    predicted_logical: bool | None
    certified: bool
    margin: float
    discarded_mass: float
    max_width: int
    elapsed_seconds: float
    pass_summaries: tuple[DecodePassSummary, ...]


@dataclass(frozen=True)
class DecodingShot:
    """One shot: its detector bit mask and the truth bit, when known."""
    det_mask: int
    actual_logical: bool | None


@dataclass(frozen=True)
class ExperimentSummary:
    """Aggregated statistics over a whole decoding run."""
    predictions: list[bool | None]
    num_certified: int
    num_low_confidence: int
    num_errors: int
    num_truth_shots: int
    num_scored_shots: int
    total_elapsed: float
    total_triggered: int
    max_width_seen: int


@dataclass(frozen=True)
class _BeamPassOutcome:
    """Internal result of a single directional beam pass."""
    failed: bool
    final_delta: float
    discarded_mass: float
    snapshots: tuple[FrontierSnapshot | None, ...]
    summary: DecodePassSummary


def _likelihood_cost(probability: float) -> float:
    """Return -log(p / (1 - p)), clamped to inf at p=0 and 0 at p>=1."""
    if probability <= 0.0:
        return math.inf
    if probability >= 1.0:
        return 0.0
    return -math.log(probability / (1.0 - probability))


def _detectors_from_mask(mask: int) -> list[int]:
    """List the set bit positions of ``mask``, lowest first."""
    detectors: list[int] = []
    while mask:
        low_bit = mask & -mask  # isolate the lowest set bit
        detectors.append(low_bit.bit_length() - 1)
        mask ^= low_bit
    return detectors


def _mask_from_bool_row(row: np.ndarray) -> int:
    """Pack a boolean row into an integer bit mask (bit i = row[i])."""
    mask = 0
    for index in np.flatnonzero(row):
        mask |= 1 << int(index)
    return mask


def _future_detcost_by_detector(faults: tuple[Fault, ...], num_detectors: int) -> tuple[tuple[float, ...], ...]:
    """For each step i and detector d, the cheapest per-detector cost of any fault at >= i touching d.

    Each fault's likelihood cost is split evenly over the detectors it touches;
    the table is built back-to-front so row i only sees faults i..end.
    """
    future_detcost: list[list[float]] = [[math.inf] * num_detectors for _ in range(len(faults) + 1)]
    next_row = future_detcost[-1]
    for fault_index in range(len(faults) - 1, -1, -1):
        row = next_row.copy()
        fault = faults[fault_index]
        det_count = fault.det_mask.bit_count()
        if det_count:
            ecost = fault.likelihood_cost / det_count
            for det_id in _detectors_from_mask(fault.det_mask):
                if ecost < row[det_id]:
                    row[det_id] = ecost
        future_detcost[fault_index] = row
        next_row = row
    return tuple(tuple(row) for row in future_detcost)


def _build_directional_model(
    *,
    faults_in_order: tuple[Fault, ...],
    num_detectors: int,
    cut_after_step: tuple[int, ...],
    direction_name: str,
) -> DirectionalModel:
    """Derive retirement masks, frontier masks, and detcost tables for one fault ordering."""
    # A detector "retires" at the last fault (in this order) that can toggle it.
    last_seen_index: dict[int, int] = {}
    for fault_index, fault in enumerate(faults_in_order):
        for det_id in _detectors_from_mask(fault.det_mask):
            last_seen_index[det_id] = fault_index

    retiring_masks = [0] * len(faults_in_order)
    for det_id, fault_index in last_seen_index.items():
        retiring_masks[fault_index] |= 1 << det_id

    # Frontier after step i: detectors introduced so far and not yet retired.
    frontier_masks_after_step = [0] * (len(faults_in_order) + 1)
    active_mask = 0
    for fault_index, fault in enumerate(faults_in_order):
        active_mask |= fault.det_mask
        active_mask &= ~retiring_masks[fault_index]
        frontier_masks_after_step[fault_index + 1] = active_mask

    return DirectionalModel(
        faults=faults_in_order,
        retiring_masks=tuple(retiring_masks),
        frontier_masks_after_step=tuple(frontier_masks_after_step),
        future_detcost=_future_detcost_by_detector(faults_in_order, num_detectors),
        cut_after_step=cut_after_step,
        direction_name=direction_name,
    )


def _build_decoder_model(circuit: stim.Circuit) -> DecoderModel:
    """Extract faults from the circuit's DEM and build aligned forward/backward models."""
    dem = circuit.detector_error_model(decompose_errors=False).flattened()

    faults: list[Fault] = []
    all_possible_dets_mask = 0

    for inst in dem:
        if inst.type != "error":
            continue

        p = float(inst.args_copy()[0])
        det_mask = 0
        flip_l0 = 0
        for target in inst.targets_copy():
            if target.is_separator():
                continue
            if target.is_relative_detector_id():
                # XOR so a detector listed twice cancels out.
                det_mask ^= 1 << target.val
            elif target.is_logical_observable_id() and target.val == 0:
                flip_l0 ^= 1

        faults.append(
            Fault(
                q=1.0 - p,
                p=p,
                delta_scale=(-p if flip_l0 else p),
                det_mask=det_mask,
                likelihood_cost=_likelihood_cost(p),
            )
        )
        all_possible_dets_mask |= det_mask

    frozen_faults = tuple(faults)
    num_faults = len(frozen_faults)

    forward_model = _build_directional_model(
        faults_in_order=frozen_faults,
        num_detectors=circuit.num_detectors,
        cut_after_step=tuple(range(num_faults + 1)),
        direction_name="forward",
    )
    # The backward pass processes the same faults in reverse; cut_after_step
    # maps its step counter back onto forward cut indices.
    backward_model = _build_directional_model(
        faults_in_order=tuple(reversed(frozen_faults)),
        num_detectors=circuit.num_detectors,
        cut_after_step=tuple(num_faults - step for step in range(num_faults + 1)),
        direction_name="backward",
    )

    # Sanity check: both directions must agree on which detectors are live at each cut.
    for cut in range(num_faults + 1):
        forward_mask = forward_model.frontier_masks_after_step[cut]
        backward_mask = backward_model.frontier_masks_after_step[num_faults - cut]
        if forward_mask != backward_mask:
            raise ValueError(
                "Internal frontier alignment check failed: the forward and backward cuts did not produce the "
                f"same frontier detector set at cut {cut}."
            )

    frontier_masks_by_cut = forward_model.frontier_masks_after_step
    frontier_detector_ids_by_cut = tuple(
        tuple(_detectors_from_mask(mask)) for mask in frontier_masks_by_cut
    )
    frontier_global_bit_masks_by_cut = tuple(
        tuple(1 << det_id for det_id in detector_ids)
        for detector_ids in frontier_detector_ids_by_cut
    )
    repeated_frontier_counts = Counter(frontier_masks_by_cut)
    repeated_frontier_mask_count = sum(1 for count in repeated_frontier_counts.values() if count > 1)
    max_frontier_mask_repeat = max(repeated_frontier_counts.values(), default=0)
    max_width = max((mask.bit_count() for mask in frontier_masks_by_cut), default=0)

    return DecoderModel(
        faults=frozen_faults,
        forward_model=forward_model,
        backward_model=backward_model,
        frontier_masks_by_cut=frontier_masks_by_cut,
        frontier_detector_ids_by_cut=frontier_detector_ids_by_cut,
        frontier_global_bit_masks_by_cut=frontier_global_bit_masks_by_cut,
        all_possible_dets_mask=all_possible_dets_mask,
        max_width=max_width,
        repeated_frontier_mask_count=repeated_frontier_mask_count,
        max_frontier_mask_repeat=max_frontier_mask_repeat,
    )


def _detcost_penalty(mismatch_mask: int, future_detcost: tuple[float, ...]) -> float:
    """Sum the per-detector future cost over every set bit; inf if any bit is unfixable."""
    total = 0.0
    pending = mismatch_mask

    while pending:
        low_bit = pending & -pending
        detector = low_bit.bit_length() - 1
        pending ^= low_bit

        best = future_detcost[detector]
        if best == math.inf:
            return math.inf
        total += best

    return total


def _compress_global_state_to_local_state(global_state: int, global_bit_masks: tuple[int, ...]) -> int:
    """Re-index a global detector mask into a dense local mask over the frontier's bits."""
    local_state = 0
    for local_index, global_bit in enumerate(global_bit_masks):
        if global_state & global_bit:
            local_state |= 1 << local_index
    return local_state


def _record_frontier_snapshot(
    *,
    model: DecoderModel,
    cut: int,
    beam: list[tuple[int, float, float]],
    discarded_mass: float,
) -> FrontierSnapshot:
    """Capture the beam's states/masses at ``cut`` in frontier-local bit coordinates."""
    global_bit_masks = model.frontier_global_bit_masks_by_cut[cut]
    one_masses = [0.0] * len(global_bit_masks)
    local_states: list[int] = []
    masses: list[float] = []

    for state, total, _ in beam:
        local_state = 0
        for local_index, global_bit in enumerate(global_bit_masks):
            if state & global_bit:
                local_state |= 1 << local_index
                # Accumulate per-bit marginal mass for the "bit is 1" side.
                one_masses[local_index] += total
        local_states.append(local_state)
        masses.append(total)

    return FrontierSnapshot(
        local_states=tuple(local_states),
        masses=tuple(masses),
        kept_total_mass=sum(masses),
        discarded_mass=discarded_mass,
        one_masses=tuple(one_masses),
    )


def _opposite_pass_cost_lower_bound(
    *,
    current_state: int,
    live_target_mask: int,
    snapshot: FrontierSnapshot,
    snapshot_lookup: dict[int, float],
    frontier_global_bit_masks: tuple[int, ...],
) -> tuple[float, bool, bool]:
    """Admissible remaining-cost lower bound from the opposite pass's snapshot.

    Returns (cost_lower_bound, exact_hit, used_marginal_bound).
    """
    if not frontier_global_bit_masks:
        return 0.0, False, False

    # The opposite pass must land on the state that XORs with ours to the
    # observed live syndrome.
    compatible_other_state = current_state ^ live_target_mask
    local_compatible_state = _compress_global_state_to_local_state(
        global_state=compatible_other_state,
        global_bit_masks=frontier_global_bit_masks,
    )

    # The opposite pass only records lower bounds on the surviving frontier-state
    # masses, together with an upper bound on all omitted mass. Therefore
    # exact_mass + discarded_mass is still an admissible upper bound on the true
    # compatible-state mass, and -log(upper_bound) is an admissible lower bound
    # on the remaining cost.
    missing_upper_bound = min(1.0, max(0.0, snapshot.discarded_mass))
    exact_mass = snapshot_lookup.get(local_compatible_state)
    exact_hit = exact_mass is not None
    if exact_mass is None:
        exact_upper_bound = missing_upper_bound
    else:
        exact_upper_bound = min(1.0, exact_mass + missing_upper_bound)

    # Marginal bound: the compatible state's mass can't exceed the mass of any
    # single bit agreeing with it, plus the omitted mass.
    kept_total_mass = min(1.0, max(0.0, snapshot.kept_total_mass))
    marginal_upper_bound = 1.0
    for local_index, observed_one_mass in enumerate(snapshot.one_masses):
        if (local_compatible_state >> local_index) & 1:
            upper_bound = min(1.0, observed_one_mass + missing_upper_bound)
        else:
            observed_zero_mass = max(0.0, kept_total_mass - observed_one_mass)
            upper_bound = min(1.0, observed_zero_mass + missing_upper_bound)
        if upper_bound < marginal_upper_bound:
            marginal_upper_bound = upper_bound

    used_marginal_bound = marginal_upper_bound < exact_upper_bound
    compatible_upper_bound = min(exact_upper_bound, marginal_upper_bound)
    if compatible_upper_bound <= 0.0:
        return math.inf, exact_hit, used_marginal_bound
    return -math.log(compatible_upper_bound), exact_hit, used_marginal_bound


def _run_beam_pass(
    *,
    model: DecoderModel,
    directional_model: DirectionalModel,
    actual_dets_mask: int,
    L: int,
    pass_index: int,
    opposite_snapshots: tuple[FrontierSnapshot | None, ...]
| None, +) -> _BeamPassOutcome: + pass_start_time = time.perf_counter() + beam: list[tuple[int, float, float]] = [(0, 1.0, 1.0)] + discarded_mass = 0.0 + + num_faults = len(directional_model.faults) + snapshots: list[FrontierSnapshot | None] = [None] * (num_faults + 1) + initial_cut = directional_model.cut_after_step[0] + snapshots[initial_cut] = _record_frontier_snapshot( + model=model, + cut=initial_cut, + beam=beam, + discarded_mass=discarded_mass, + ) + + exact_upper_hits = 0 + exact_upper_misses = 0 + marginal_tighter_count = 0 + opposite_selected_count = 0 + detcost_selected_count = 0 + opposite_available_steps = 0 + frontier_width_total = len(model.frontier_detector_ids_by_cut[initial_cut]) + frontier_width_steps = 1 + max_frontier_width = len(model.frontier_detector_ids_by_cut[initial_cut]) + beam_size_total = len(beam) + beam_size_steps = 1 + max_beam_size = len(beam) + + for fault_index, fault in enumerate(directional_model.faults): + collapsed_probs: dict[int, list[float]] = {} + total_mass = 0.0 + retiring_mask = directional_model.retiring_masks[fault_index] + + if retiring_mask == 0: + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + total_mass += absent_total + entry = collapsed_probs.get(state) + if entry is None: + collapsed_probs[state] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + total_mass += present_total + entry = collapsed_probs.get(toggled) + if entry is None: + collapsed_probs[toggled] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + else: + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + if (state & retiring_mask) == expected_bits: + shrunk = state 
& keep_mask + total_mass += absent_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + if (toggled & retiring_mask) == expected_bits: + shrunk = toggled & keep_mask + total_mass += present_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + + if total_mass == 0.0: + summary = DecodePassSummary( + pass_index=pass_index, + direction=directional_model.direction_name, + final_delta=0.0, + discarded_mass=discarded_mass, + elapsed_seconds=time.perf_counter() - pass_start_time, + exact_upper_hits=exact_upper_hits, + exact_upper_misses=exact_upper_misses, + marginal_tighter_count=marginal_tighter_count, + opposite_selected_count=opposite_selected_count, + detcost_selected_count=detcost_selected_count, + mean_frontier_width=frontier_width_total / max(1, frontier_width_steps), + max_frontier_width=max_frontier_width, + mean_beam_size=beam_size_total / max(1, beam_size_steps), + max_beam_size=max_beam_size, + opposite_available_steps=opposite_available_steps, + ) + return _BeamPassOutcome( + failed=True, + final_delta=0.0, + discarded_mass=discarded_mass, + snapshots=tuple(snapshots), + summary=summary, + ) + + next_cut = directional_model.cut_after_step[fault_index + 1] + frontier_mask = directional_model.frontier_masks_after_step[fault_index + 1] + live_target_mask = actual_dets_mask & frontier_mask + next_future_detcost = directional_model.future_detcost[fault_index + 1] + frontier_global_bit_masks = model.frontier_global_bit_masks_by_cut[next_cut] + + opposite_snapshot = None if opposite_snapshots is None else opposite_snapshots[next_cut] + opposite_lookup: dict[int, float] | None = None + if 
opposite_snapshot is not None: + opposite_available_steps += 1 + opposite_lookup = {state: mass for state, mass in zip(opposite_snapshot.local_states, opposite_snapshot.masses)} + + ranked_states: list[tuple[float, float, int, float]] = [] + for state, (total, delta) in collapsed_probs.items(): + mismatch_mask = state ^ live_target_mask + heuristic_cost = _detcost_penalty( + mismatch_mask=mismatch_mask, + future_detcost=next_future_detcost, + ) + if opposite_snapshot is not None and opposite_lookup is not None: + opposite_cost, exact_hit, used_marginal_bound = _opposite_pass_cost_lower_bound( + current_state=state, + live_target_mask=live_target_mask, + snapshot=opposite_snapshot, + snapshot_lookup=opposite_lookup, + frontier_global_bit_masks=frontier_global_bit_masks, + ) + if exact_hit: + exact_upper_hits += 1 + else: + exact_upper_misses += 1 + if used_marginal_bound: + marginal_tighter_count += 1 + if opposite_cost > heuristic_cost: + heuristic_cost = opposite_cost + opposite_selected_count += 1 + else: + detcost_selected_count += 1 + else: + detcost_selected_count += 1 + + if heuristic_cost == math.inf: + rank_score = -math.inf + else: + rank_score = math.log(total) - heuristic_cost + ranked_states.append((rank_score, total, state, delta)) + + dropped_mass = 0.0 + if len(ranked_states) > L: + ranked_states.sort(reverse=True) + kept = ranked_states[:L] + beam = [(state, total, delta) for _, total, state, delta in kept] + kept_mass = sum(total for _, total, _, _ in kept) + dropped_mass = total_mass - kept_mass + else: + beam = [(state, total, delta) for _, total, state, delta in ranked_states] + + inv_total_mass = 1.0 / total_mass + discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass + beam = [ + (state, total * inv_total_mass, delta * inv_total_mass) + for state, total, delta in beam + ] + + snapshots[next_cut] = _record_frontier_snapshot( + model=model, + cut=next_cut, + beam=beam, + discarded_mass=discarded_mass, + ) + + frontier_width = 
len(model.frontier_detector_ids_by_cut[next_cut]) + frontier_width_total += frontier_width + frontier_width_steps += 1 + max_frontier_width = max(max_frontier_width, frontier_width) + beam_size = len(beam) + beam_size_total += beam_size + beam_size_steps += 1 + max_beam_size = max(max_beam_size, beam_size) + + _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + summary = DecodePassSummary( + pass_index=pass_index, + direction=directional_model.direction_name, + final_delta=final_delta, + discarded_mass=discarded_mass, + elapsed_seconds=time.perf_counter() - pass_start_time, + exact_upper_hits=exact_upper_hits, + exact_upper_misses=exact_upper_misses, + marginal_tighter_count=marginal_tighter_count, + opposite_selected_count=opposite_selected_count, + detcost_selected_count=detcost_selected_count, + mean_frontier_width=frontier_width_total / max(1, frontier_width_steps), + max_frontier_width=max_frontier_width, + mean_beam_size=beam_size_total / max(1, beam_size_steps), + max_beam_size=max_beam_size, + opposite_available_steps=opposite_available_steps, + ) + return _BeamPassOutcome( + failed=False, + final_delta=final_delta, + discarded_mass=discarded_mass, + snapshots=tuple(snapshots), + summary=summary, + ) + + +def decode_beam_search_iterative( + model: DecoderModel, + actual_dets_mask: int, + L: int, + *, + num_passes: int, +) -> BeamDecodeResult: + start_time = time.perf_counter() + + if (actual_dets_mask & ~model.all_possible_dets_mask) != 0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=0.0, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + pass_summaries=(), + ) + + if num_passes <= 0: + raise ValueError("num_passes must be positive.") + + opposite_snapshots: tuple[FrontierSnapshot | None, ...] 
| None = None + pass_summaries: list[DecodePassSummary] = [] + last_outcome: _BeamPassOutcome | None = None + + for pass_offset in range(num_passes): + pass_index = pass_offset + 1 + directional_model = model.forward_model if pass_offset % 2 == 0 else model.backward_model + last_outcome = _run_beam_pass( + model=model, + directional_model=directional_model, + actual_dets_mask=actual_dets_mask, + L=L, + pass_index=pass_index, + opposite_snapshots=opposite_snapshots, + ) + pass_summaries.append(last_outcome.summary) + if last_outcome.failed: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=last_outcome.discarded_mass, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + pass_summaries=tuple(pass_summaries), + ) + opposite_snapshots = last_outcome.snapshots + + assert last_outcome is not None + final_delta = last_outcome.final_delta + margin = abs(final_delta) + certified = margin > last_outcome.discarded_mass + + if final_delta == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=margin, + discarded_mass=last_outcome.discarded_mass, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + pass_summaries=tuple(pass_summaries), + ) + return BeamDecodeResult( + predicted_logical=final_delta < 0.0, + certified=certified, + margin=margin, + discarded_mass=last_outcome.discarded_mass, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + pass_summaries=tuple(pass_summaries), + ) + + +def decode_beam_search_detcost_ranked( + model: DecoderModel, + actual_dets_mask: int, + L: int, +) -> BeamDecodeResult: + return decode_beam_search_iterative( + model=model, + actual_dets_mask=actual_dets_mask, + L=L, + num_passes=1, + ) + + +def _as_bool_2d(data: np.ndarray, *, expected_cols: int, description: str) -> np.ndarray: + arr = np.asarray(data) + if arr.ndim != 2: + raise ValueError(f"Expected {description} to 
be a 2D array but got shape {arr.shape!r}.") + if arr.shape[1] != expected_cols: + raise ValueError( + f"Expected {description} to have {expected_cols} columns but got {arr.shape[1]}." + ) + if arr.dtype != np.bool_: + arr = arr.astype(np.bool_, copy=False) + return arr + + +def _sample_shot_arrays( + circuit: stim.Circuit, + *, + shots: int, + seed: int | None, +) -> tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets, obs = sampler.sample(shots=shots, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=circuit.num_detectors, description="sampled detector data"), + _as_bool_2d(obs, expected_cols=circuit.num_observables, description="sampled observable data"), + ) + + +def _read_detector_shot_arrays( + *, + path: str, + fmt: str, + num_detectors: int, + num_observables: int, +) -> tuple[np.ndarray, np.ndarray | None]: + common_kwargs = dict( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=num_detectors, + num_observables=num_observables, + ) + + if num_observables: + try: + dets, obs = stim.read_shot_data_file(**common_kwargs, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=num_detectors, description="input detector data"), + _as_bool_2d(obs, expected_cols=num_observables, description="appended observable data"), + ) + except TypeError: + flat = stim.read_shot_data_file(**common_kwargs) + flat = _as_bool_2d( + flat, + expected_cols=num_detectors + num_observables, + description="combined detector/observable input data", + ) + return flat[:, :num_detectors], flat[:, num_detectors:] + + flat = stim.read_shot_data_file(**common_kwargs) + return _as_bool_2d(flat, expected_cols=num_detectors, description="input detector data"), None + + +def _read_observable_shot_array(*, path: str, fmt: str, num_observables: int) -> np.ndarray: + obs = stim.read_shot_data_file( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=0, 
+ num_observables=num_observables, + ) + return _as_bool_2d(obs, expected_cols=num_observables, description="observable input data") + + +def _apply_shot_range( + dets: np.ndarray, + obs: np.ndarray | None, + *, + shot_range_begin: int, + shot_range_end: int, +) -> tuple[np.ndarray, np.ndarray | None]: + if not (shot_range_begin or shot_range_end): + return dets, obs + + if shot_range_end < shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if shot_range_end > len(dets): + raise ValueError( + f"Shot range end {shot_range_end} is past the end of the shot data (size {len(dets)})." + ) + + dets = dets[shot_range_begin:shot_range_end] + if obs is not None: + obs = obs[shot_range_begin:shot_range_end] + return dets, obs + + +def _shots_from_arrays(dets: np.ndarray, obs: np.ndarray | None) -> list[DecodingShot]: + shots: list[DecodingShot] = [] + for shot_index in range(dets.shape[0]): + actual_logical = None if obs is None else bool(obs[shot_index, 0]) + shots.append( + DecodingShot( + det_mask=_mask_from_bool_row(dets[shot_index]), + actual_logical=actual_logical, + ) + ) + return shots + + +def _resolve_stdin_path_if_needed(path: str, *, temp_dir: str, stem: str) -> str: + if path != "-": + return path + temp_path = str(Path(temp_dir) / f"{stem}.bin") + with open(temp_path, "wb") as f: + f.write(sys.stdin.buffer.read()) + return temp_path + + +def _resolve_stdout_path_if_needed(path: str, *, temp_dir: str, stem: str) -> tuple[str, bool]: + if path != "-": + return path, False + return str(Path(temp_dir) / f"{stem}.bin"), True + + +def _copy_file_to_stdout(path: str) -> None: + sys.stdout.flush() + with open(path, "rb") as f: + shutil.copyfileobj(f, sys.stdout.buffer) + sys.stdout.buffer.flush() + + +def _load_shots( + circuit: stim.Circuit, + args: argparse.Namespace, + *, + temp_dir: str, +) -> list[DecodingShot]: + if args.in_file: + in_path = _resolve_stdin_path_if_needed(args.in_file, 
temp_dir=temp_dir, stem="shots_in") + appended_obs_count = circuit.num_observables if args.in_includes_appended_observables else 0 + dets, obs = _read_detector_shot_arrays( + path=in_path, + fmt=args.in_format, + num_detectors=circuit.num_detectors, + num_observables=appended_obs_count, + ) + + if args.obs_in_file: + obs_in_path = _resolve_stdin_path_if_needed(args.obs_in_file, temp_dir=temp_dir, stem="obs_in") + obs = _read_observable_shot_array( + path=obs_in_path, + fmt=args.obs_in_format, + num_observables=circuit.num_observables, + ) + if len(obs) != len(dets): + raise ValueError("Observable input ended before, or after, the detector shot data.") + else: + dets, obs = _sample_shot_arrays(circuit, shots=args.sample_num_shots, seed=args.sample_seed) + + dets, obs = _apply_shot_range( + dets, + obs, + shot_range_begin=args.shot_range_begin, + shot_range_end=args.shot_range_end, + ) + return _shots_from_arrays(dets, obs) + + +def _print_run_header( + *, + circuit: stim.Circuit, + model: DecoderModel, + args: argparse.Namespace, + num_shots: int, + log_stream, +) -> None: + print(f"Running on circuit {args.circuit}", file=log_stream) + print(f"Total Detectors: {circuit.num_detectors}", file=log_stream) + print(f"Total Observables: {circuit.num_observables}", file=log_stream) + print(f"Beam: {args.beam}", file=log_stream) + print(f"Num Passes: {args.num_passes}", file=log_stream) + print("Frontier Matching: keyed by cut index (forward/backward cut check verified)", file=log_stream) + if args.num_passes > 1: + print( + f"Repeated Frontiers: {model.repeated_frontier_mask_count} repeated detector-set masks; " + f"max repeat={model.max_frontier_mask_repeat}", + file=log_stream, + ) + if args.in_file: + print(f"Shot Input: {args.in_file}", file=log_stream) + print(f"Shot Input Format: {args.in_format}", file=log_stream) + if args.in_includes_appended_observables: + print("Observable Input: appended to --in", file=log_stream) + elif args.obs_in_file: + print(f"Observable 
Input: {args.obs_in_file}", file=log_stream) + print(f"Observable Format: {args.obs_in_format}", file=log_stream) + else: + print("Observable Input: none", file=log_stream) + else: + print(f"Sample Seed: {args.sample_seed}", file=log_stream) + print(f"Requested Shots: {args.sample_num_shots}", file=log_stream) + if args.shot_range_begin or args.shot_range_end: + print( + f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})", + file=log_stream, + ) + print(f"Num Shots: {num_shots}", file=log_stream) + + +def _print_pass_summary( + *, + pass_summary: DecodePassSummary, + log_stream, +) -> None: + print( + f" pass={pass_summary.pass_index} direction={pass_summary.direction} " + f"final_delta={pass_summary.final_delta:.6e} discarded_mass={pass_summary.discarded_mass:.6e} " + f"exact_hits={pass_summary.exact_upper_hits} exact_misses={pass_summary.exact_upper_misses} " + f"marginal_tighter={pass_summary.marginal_tighter_count} " + f"opp_selected={pass_summary.opposite_selected_count} detcost_selected={pass_summary.detcost_selected_count} " + f"opp_steps={pass_summary.opposite_available_steps} " + f"mean_frontier_width={pass_summary.mean_frontier_width:.2f} max_frontier_width={pass_summary.max_frontier_width} " + f"mean_beam={pass_summary.mean_beam_size:.2f} max_beam={pass_summary.max_beam_size} " + f"elapsed_seconds={pass_summary.elapsed_seconds:.6f}", + file=log_stream, + ) + + +def run_experiment(args: argparse.Namespace) -> ExperimentSummary: + circuit = stim.Circuit.from_file(args.circuit) + if circuit.num_observables != 1: + raise ValueError( + "This decoder currently supports exactly one logical observable, because it only tracks L0. " + f"The circuit has {circuit.num_observables} observables." 
+ ) + + model = _build_decoder_model(circuit) + log_stream = sys.stderr if args.out_file == "-" else sys.stdout + + with tempfile.TemporaryDirectory() as temp_dir: + shots = _load_shots(circuit, args, temp_dir=temp_dir) + _print_run_header(circuit=circuit, model=model, args=args, num_shots=len(shots), log_stream=log_stream) + + num_errors = 0 + num_low_confidence = 0 + num_certified = 0 + num_truth_shots = 0 + num_scored_shots = 0 + total_elapsed = 0.0 + total_triggered = 0 + max_width_seen = 0 + predictions: list[bool | None] = [] + + pass_aggregates = [ + { + "count": 0, + "elapsed_seconds": 0.0, + "exact_upper_hits": 0, + "exact_upper_misses": 0, + "marginal_tighter_count": 0, + "opposite_selected_count": 0, + "detcost_selected_count": 0, + "opposite_available_steps": 0, + "mean_frontier_width_sum": 0.0, + "max_frontier_width": 0, + "mean_beam_size_sum": 0.0, + "max_beam_size": 0, + } + for _ in range(args.num_passes) + ] + + for shot_index, shot in enumerate(shots): + result = decode_beam_search_iterative( + model, + shot.det_mask, + args.beam, + num_passes=args.num_passes, + ) + predictions.append(result.predicted_logical) + + success: bool | None + if shot.actual_logical is None or result.predicted_logical is None: + success = None + else: + success = result.predicted_logical == shot.actual_logical + + if result.predicted_logical is None: + num_low_confidence += 1 + if shot.actual_logical is not None: + num_truth_shots += 1 + if success is not None: + num_scored_shots += 1 + if not success: + num_errors += 1 + if result.certified: + num_certified += 1 + + total_elapsed += result.elapsed_seconds + triggered_dets = shot.det_mask.bit_count() + total_triggered += triggered_dets + max_width_seen = max(max_width_seen, result.max_width) + + for pass_summary in result.pass_summaries: + agg = pass_aggregates[pass_summary.pass_index - 1] + agg["count"] += 1 + agg["elapsed_seconds"] += pass_summary.elapsed_seconds + agg["exact_upper_hits"] += 
pass_summary.exact_upper_hits + agg["exact_upper_misses"] += pass_summary.exact_upper_misses + agg["marginal_tighter_count"] += pass_summary.marginal_tighter_count + agg["opposite_selected_count"] += pass_summary.opposite_selected_count + agg["detcost_selected_count"] += pass_summary.detcost_selected_count + agg["opposite_available_steps"] += pass_summary.opposite_available_steps + agg["mean_frontier_width_sum"] += pass_summary.mean_frontier_width + agg["max_frontier_width"] = max(agg["max_frontier_width"], pass_summary.max_frontier_width) + agg["mean_beam_size_sum"] += pass_summary.mean_beam_size + agg["max_beam_size"] = max(agg["max_beam_size"], pass_summary.max_beam_size) + + shots_done = shot_index + 1 + error_rate_so_far = num_errors / num_scored_shots if num_scored_shots else 0.0 + print( + f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} " + f"low_conf_so_far={num_low_confidence} scored_shots_so_far={num_scored_shots} " + f"error_rate_so_far={error_rate_so_far:.6f} elapsed_total_seconds={total_elapsed:.6f}", + file=log_stream, + ) + + if args.print_per_shot: + print( + f"shot={shot_index} triggered_detectors={triggered_dets} " + f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} " + f"success={success} certified={result.certified} " + f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} " + f"elapsed_seconds={result.elapsed_seconds:.6f}", + file=log_stream, + ) + for pass_summary in result.pass_summaries: + _print_pass_summary(pass_summary=pass_summary, log_stream=log_stream) + + if args.out_file: + output_path, copy_to_stdout = _resolve_stdout_path_if_needed( + args.out_file, + temp_dir=temp_dir, + stem="predictions_out", + ) + prediction_data = np.zeros((len(predictions), circuit.num_observables), dtype=np.bool_) + for shot_index, predicted_logical in enumerate(predictions): + prediction_data[shot_index, 0] = bool(predicted_logical) if predicted_logical is not None else False + + 
if args.out_format == "ptb64" and len(prediction_data) % 64 != 0: + raise ValueError("The ptb64 format requires the number of shots to be a multiple of 64.") + + stim.write_shot_data_file( + data=prediction_data, + path=output_path, + format=args.out_format, + num_measurements=0, + num_detectors=0, + num_observables=circuit.num_observables, + ) + if copy_to_stdout: + _copy_file_to_stdout(output_path) + if num_low_confidence: + print( + f"warning: wrote {num_low_confidence} low-confidence predictions as L0=0 because Stim result " + "files can only store bits, not unknown values.", + file=log_stream, + ) + + print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}", file=log_stream) + print(f"Max Width: {max_width_seen}", file=log_stream) + print(f"Certified Shots: {num_certified}", file=log_stream) + print(f"Low Confidence: {num_low_confidence}", file=log_stream) + print(f"Truth-Labeled Shots: {num_truth_shots}", file=log_stream) + print(f"Scored Shots: {num_scored_shots}", file=log_stream) + if num_truth_shots: + print(f"Logical Errors: {num_errors}", file=log_stream) + else: + print("Logical Errors: n/a", file=log_stream) + print(f"Total Seconds: {total_elapsed:.6f}", file=log_stream) + print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}", file=log_stream) + + print("Pass Diagnostics:", file=log_stream) + for pass_index, agg in enumerate(pass_aggregates, start=1): + if agg["count"] == 0: + continue + print( + f" pass={pass_index} direction={'forward' if pass_index % 2 == 1 else 'backward'} " + f"mean_elapsed_seconds={agg['elapsed_seconds'] / agg['count']:.6f} " + f"exact_hits={agg['exact_upper_hits']} exact_misses={agg['exact_upper_misses']} " + f"marginal_tighter={agg['marginal_tighter_count']} " + f"opp_selected={agg['opposite_selected_count']} detcost_selected={agg['detcost_selected_count']} " + f"opp_steps={agg['opposite_available_steps']} " + f"mean_frontier_width={agg['mean_frontier_width_sum'] / agg['count']:.2f} " + 
f"max_frontier_width={agg['max_frontier_width']} " + f"mean_beam={agg['mean_beam_size_sum'] / agg['count']:.2f} " + f"max_beam={agg['max_beam_size']}", + file=log_stream, + ) + + return ExperimentSummary( + predictions=predictions, + num_certified=num_certified, + num_low_confidence=num_low_confidence, + num_errors=num_errors, + num_truth_shots=num_truth_shots, + num_scored_shots=num_scored_shots, + total_elapsed=total_elapsed, + total_triggered=total_triggered, + max_width_seen=max_width_seen, + ) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Run trellis beam decoding with detcost-style future penalties and optional iterative " + "forward/backward cross-pass frontier-mass heuristics, with Stim-compatible shot-data I/O options." + ), + allow_abbrev=False, + ) + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") + parser.add_argument( + "--num-passes", + "--num_passes", + dest="num_passes", + type=int, + default=1, + help=( + "Number of alternating beam passes to run. Pass 1 is the original forward pass, pass 2 is backward, " + "pass 3 is forward again, and so on." + ), + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=None, + help="Number of sampled shots. Defaults to 1 unless --in is provided.", + ) + parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.") + parser.add_argument( + "--shot-range-begin", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." + ), + ) + parser.add_argument( + "--shot-range-end", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." 
+ ), + ) + parser.add_argument( + "--in", + dest="in_file", + default="", + help="File to read detection events from (use - for stdin).", + ) + parser.add_argument( + "--in-format", + "--in_format", + dest="in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--in-includes-appended-observables", + "--in_includes_appended_observables", + dest="in_includes_appended_observables", + action="store_true", + help="Assume the observable flips are appended to each shot in --in.", + ) + parser.add_argument( + "--obs-in", + "--obs_in", + dest="obs_in_file", + default="", + help="File to read observable flips from (use - for stdin).", + ) + parser.add_argument( + "--obs-in-format", + "--obs_in_format", + dest="obs_in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --obs-in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--out", + dest="out_file", + default="", + help="File to write predicted observable flips to (use - for stdout).", + ) + parser.add_argument( + "--out-format", + "--out_format", + dest="out_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file written by --out ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--print-per-shot", + action="store_true", + help="Print a detailed line per decoded shot, plus one line per decoding pass.", + ) + args = parser.parse_args() + + if args.sample_num_shots is None: + # Preserve the original script's one-shot default while still allowing + # file input without requiring --sample-num-shots 0. 
+ args.sample_num_shots = 0 if args.in_file else 1 + + if args.beam <= 0: + raise ValueError("--beam must be positive.") + if args.num_passes <= 0: + raise ValueError("--num-passes must be positive.") + if args.sample_num_shots < 0: + raise ValueError("--sample-num-shots must be non-negative.") + if args.sample_seed is not None and args.sample_seed < 0: + raise ValueError("--sample-seed must be non-negative.") + if args.shot_range_begin < 0 or args.shot_range_end < 0: + raise ValueError("--shot-range-begin and --shot-range-end must be non-negative.") + if args.shot_range_end < args.shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if args.in_includes_appended_observables and args.obs_in_file: + raise ValueError( + "Choose either --in-includes-appended-observables or --obs-in, not both." + ) + if args.obs_in_file and not args.in_file: + raise ValueError("Cannot load observable flips from --obs-in without also providing --in.") + if args.in_file == "-" and args.obs_in_file == "-": + raise ValueError("At most one of --in and --obs-in may read from stdin.") + + num_shot_sources = int(args.sample_num_shots > 0) + int(bool(args.in_file)) + if num_shot_sources != 1: + raise ValueError("Requires exactly one source of shots: either --sample-num-shots > 0 or --in.") + + return args + + +if __name__ == "__main__": + run_experiment(_parse_args()) diff --git a/src/py/astar/trellis_beam_opt_singleton_lp_ranked.py b/src/py/astar/trellis_beam_opt_singleton_lp_ranked.py new file mode 100644 index 0000000..634e465 --- /dev/null +++ b/src/py/astar/trellis_beam_opt_singleton_lp_ranked.py @@ -0,0 +1,1218 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import math +import shutil +import sys +import tempfile +import time +from bisect import bisect_left +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +import 
from __future__ import annotations

import argparse
import math
import shutil
import sys
import tempfile
import time
from collections import OrderedDict
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path

import numpy as np

try:  # pragma: no cover - optional at runtime in this environment.
    import stim  # type: ignore
except ModuleNotFoundError:  # pragma: no cover - exercised when Stim is unavailable.
    stim = None

try:  # pragma: no cover - optional at runtime.
    from scipy.optimize import linprog  # type: ignore
    from scipy.sparse import csr_matrix  # type: ignore
except ModuleNotFoundError:  # pragma: no cover - exercised if SciPy is unavailable.
    linprog = None
    csr_matrix = None


STIM_RESULT_FORMATS = ("01", "b8", "r8", "ptb64", "hits", "dets")
STIM_RESULT_FORMATS_HELP = "/".join(STIM_RESULT_FORMATS)
INF = float("inf")


@dataclass(frozen=True)
class Fault:
    """One error mechanism from the detector error model.

    Attributes:
        q: Probability that the fault does NOT occur (1 - p).
        p: Probability that the fault occurs.
        delta_scale: Multiplier applied to the signed logical-margin accumulator
            when the fault is present: -p if the fault flips logical observable
            L0, +p otherwise.
        det_mask: Bitmask of detectors toggled by the fault.
        detector_ids: The same detectors as an explicit tuple of indices.
        likelihood_cost: -log(p / (1 - p)); lower cost means more likely.
    """

    q: float
    p: float
    delta_scale: float
    det_mask: int
    detector_ids: tuple[int, ...]
    likelihood_cost: float


@dataclass(frozen=True)
class DecoderModel:
    """Preprocessed, fault-ordered view of a circuit's detector error model.

    Attributes:
        faults: All error mechanisms in DEM order.
        retiring_masks: retiring_masks[i] is the bitmask of detectors whose
            last-touching fault is fault i (they "retire" after step i).
        live_masks_after: live_masks_after[i] is the bitmask of detectors still
            touchable by faults i, i+1, ... (index 0 is before any fault).
        plain_future_detcost: plain_future_detcost[i][d] is the cheapest
            per-detector cost (likelihood_cost / num detectors) among faults
            >= i that touch detector d; used by the plain heuristic.
        detector_to_faults: For each detector, the sorted indices of faults
            that touch it.
        all_possible_dets_mask: Union of every fault's det_mask.
        max_width: Maximum number of simultaneously live detectors.
        num_detectors: Total detector count of the circuit.
    """

    faults: tuple[Fault, ...]
    retiring_masks: tuple[int, ...]
    live_masks_after: tuple[int, ...]
    plain_future_detcost: tuple[tuple[float, ...], ...]
    detector_to_faults: tuple[tuple[int, ...], ...]
    all_possible_dets_mask: int
    max_width: int
    num_detectors: int


@dataclass(frozen=True)
class BeamDecodeResult:
    """Outcome of decoding a single shot with the beam sweep."""

    predicted_logical: bool | None  # None when the decoder has no confident answer.
    certified: bool  # True when the margin exceeds all discarded probability mass.
    margin: float
    discarded_mass: float
    max_width: int
    elapsed_seconds: float
    heuristic_calls: int = 0
    cache_hits: int = 0
    lp_calls: int = 0
    lp_seconds: float = 0.0


@dataclass(frozen=True)
class DecodingShot:
    """One shot to decode: fired detectors plus the ground-truth label, if any."""

    det_mask: int
    actual_logical: bool | None


@dataclass(frozen=True)
class ExperimentSummary:
    """Aggregate statistics returned by run_experiment()."""

    predictions: list[bool | None]
    num_certified: int
    num_low_confidence: int
    num_errors: int
    num_truth_shots: int
    num_scored_shots: int
    total_elapsed: float
    total_triggered: int
    max_width_seen: int
    total_heuristic_calls: int
    total_cache_hits: int
    total_lp_calls: int
    total_lp_seconds: float


@dataclass
class ShotSingletonLPContext:
    """Per-shot mutable cursor state for OptimalSingletonLPEvaluator.

    Tracks how far the sweep has advanced (row_index), per-detector offsets
    into the future-fault lists, and a mark array used to visit each fault at
    most once per evaluate() call without reallocating a set.
    """

    row_index: int
    detector_fault_offsets: list[int]
    seen_fault_marks: list[int]
    current_mark: int = 0

    def next_mark(self) -> int:
        """Return a fresh mark value, resetting the mark array on (rare) wraparound."""
        self.current_mark += 1
        if self.current_mark >= (1 << 60):
            self.seen_fault_marks[:] = [0] * len(self.seen_fault_marks)
            self.current_mark = 1
        return self.current_mark


class UnionFind:
    """Classic disjoint-set forest with path halving and union by rank."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        """Return the representative of x's set (with path halving)."""
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> None:
        """Merge the sets containing a and b (union by rank)."""
        ra = self.find(a)
        rb = self.find(b)
        if ra == rb:
            return
        if self.rank[ra] < self.rank[rb]:
            self.parent[ra] = rb
        elif self.rank[ra] > self.rank[rb]:
            self.parent[rb] = ra
        else:
            self.parent[rb] = ra
            self.rank[ra] += 1


class OptimalSingletonLPEvaluator:
    """Evaluates the exact singleton-budget LP on a suffix of future faults.

    The dual LP is
        maximize sum_d y_d
        subject to sum_{d in support(e) ∩ M} y_d <= w_e for each future fault e
                   y_d >= 0
    where M is the current residual live-detector mismatch mask.

    Results are cached across shots by (suffix_row, mismatch_mask). Within one shot,
    the suffix row advances monotonically, so per-detector pointers into the future
    fault lists can be updated incrementally instead of re-bisecting each time.
    """

    def __init__(
        self,
        model: DecoderModel,
        *,
        use_cache: bool = True,
        cache_max_entries: int = 0,
        split_components: bool = True,
    ) -> None:
        self.model = model
        self.use_cache = use_cache
        self.cache_max_entries = cache_max_entries  # 0 means unlimited.
        self.split_components = split_components
        self.cache: OrderedDict[tuple[int, int], float] = OrderedDict()
        # Cumulative statistics; callers snapshot these to attribute work per shot.
        self.heuristic_calls = 0
        self.cache_hits = 0
        self.lp_calls = 0
        self.lp_seconds = 0.0

    def clear_cache(self) -> None:
        """Drop every cached LP value."""
        self.cache.clear()

    def begin_shot(self) -> ShotSingletonLPContext:
        """Create a fresh cursor for one decoding sweep over the fault list."""
        return ShotSingletonLPContext(
            row_index=0,
            detector_fault_offsets=[0] * self.model.num_detectors,
            seen_fault_marks=[0] * len(self.model.faults),
        )

    def advance_past_fault(self, context: ShotSingletonLPContext, fault_index: int) -> None:
        """Move the cursor just past fault_index, updating per-detector offsets."""
        context.row_index = fault_index + 1
        target_row = context.row_index
        fault = self.model.faults[fault_index]
        for detector in fault.detector_ids:
            future_faults = self.model.detector_to_faults[detector]
            pos = context.detector_fault_offsets[detector]
            while pos < len(future_faults) and future_faults[pos] < target_row:
                pos += 1
            context.detector_fault_offsets[detector] = pos

    def evaluate(self, context: ShotSingletonLPContext, mismatch_mask: int) -> float:
        """Return the exact singleton lower bound for clearing mismatch_mask.

        Returns INF when some mismatched detector cannot be reached by any
        remaining fault. Fast paths skip the LP for trivially solvable cases.
        """
        self.heuristic_calls += 1

        if mismatch_mask == 0:
            return 0.0

        cache_key = (context.row_index, mismatch_mask)
        if self.use_cache:
            cached = self.cache.get(cache_key)
            if cached is not None:
                self.cache_hits += 1
                self.cache.move_to_end(cache_key)
                return cached

        if linprog is None or csr_matrix is None:
            raise RuntimeError(
                "The exact singleton-LP heuristic requires SciPy (scipy.optimize.linprog)."
            )

        # Collect, over the remaining faults, the cheapest weight for each
        # distinct support pattern (det_mask restricted to the mismatch).
        mark = context.next_mark()
        seen_fault_marks = context.seen_fault_marks
        support_to_weight: dict[int, float] = {}
        covered_mask = 0

        for detector in _detectors_from_mask(mismatch_mask):
            future_faults = self.model.detector_to_faults[detector]
            start = context.detector_fault_offsets[detector]
            for fault_index in future_faults[start:]:
                if seen_fault_marks[fault_index] == mark:
                    continue
                seen_fault_marks[fault_index] = mark

                fault = self.model.faults[fault_index]
                support_mask = fault.det_mask & mismatch_mask
                if support_mask == 0:
                    continue
                covered_mask |= support_mask
                previous = support_to_weight.get(support_mask)
                if previous is None or fault.likelihood_cost < previous:
                    support_to_weight[support_mask] = fault.likelihood_cost

        if covered_mask != mismatch_mask:
            # Some mismatched detector is unreachable: infeasible.
            return self._store(cache_key, INF)

        if len(support_to_weight) == 1:
            # A single support pattern covers everything; dual optimum is its weight.
            only_value = next(iter(support_to_weight.values()))
            return self._store(cache_key, only_value)

        if mismatch_mask.bit_count() == 1:
            # One detector: the answer is the cheapest fault touching it.
            best = min(support_to_weight.values())
            return self._store(cache_key, best)

        start_time = time.perf_counter()
        value = self._solve_support_system(support_to_weight=support_to_weight, mismatch_mask=mismatch_mask)
        self.lp_seconds += time.perf_counter() - start_time
        return self._store(cache_key, value)

    def _store(self, cache_key: tuple[int, int], value: float) -> float:
        """Cache (with optional LRU eviction) and return value."""
        if self.use_cache:
            self.cache[cache_key] = value
            self.cache.move_to_end(cache_key)
            if self.cache_max_entries > 0:
                while len(self.cache) > self.cache_max_entries:
                    self.cache.popitem(last=False)
        return value

    def _solve_support_system(self, *, support_to_weight: dict[int, float], mismatch_mask: int) -> float:
        """Solve the dual LP, optionally decomposed into disconnected components."""
        active_detectors = _detectors_from_mask(mismatch_mask)
        if not active_detectors:
            return 0.0

        detector_index = {detector: i for i, detector in enumerate(active_detectors)}

        if not self.split_components:
            return self._solve_component_lp(
                supports=tuple(support_to_weight.items()),
                detector_index=detector_index,
                component_detectors=tuple(active_detectors),
            )

        # Union detectors that appear in a common support: the LP separates
        # across the resulting connected components.
        uf = UnionFind(len(active_detectors))
        support_bits_cache: dict[int, tuple[int, ...]] = {}
        for support_mask in support_to_weight:
            bits = _detectors_from_mask(support_mask)
            support_bits_cache[support_mask] = tuple(bits)
            if len(bits) > 1:
                base = detector_index[bits[0]]
                for detector in bits[1:]:
                    uf.union(base, detector_index[detector])

        component_detectors: dict[int, list[int]] = {}
        for detector in active_detectors:
            root = uf.find(detector_index[detector])
            component_detectors.setdefault(root, []).append(detector)

        component_supports: dict[int, list[tuple[int, float]]] = {root: [] for root in component_detectors}
        for support_mask, weight in support_to_weight.items():
            bits = support_bits_cache[support_mask]
            root = uf.find(detector_index[bits[0]])
            component_supports[root].append((support_mask, weight))

        total = 0.0
        for root, detectors in component_detectors.items():
            supports = component_supports[root]
            if len(detectors) == 1:
                # Singleton component: optimum is the cheapest covering support.
                total += min(weight for _support_mask, weight in supports)
                continue
            total += self._solve_component_lp(
                supports=tuple(supports),
                detector_index=detector_index,
                component_detectors=tuple(detectors),
            )
        return total

    def _solve_component_lp(
        self,
        *,
        supports: tuple[tuple[int, float], ...],
        detector_index: dict[int, int],
        component_detectors: tuple[int, ...],
    ) -> float:
        """Solve one connected component's dual LP via HiGHS."""
        if linprog is None or csr_matrix is None:
            raise RuntimeError(
                "The exact singleton-LP heuristic requires SciPy (scipy.optimize.linprog)."
            )

        local_index = {detector: i for i, detector in enumerate(component_detectors)}
        row_indices: list[int] = []
        col_indices: list[int] = []
        data: list[float] = []
        rhs: list[float] = []

        for row, (support_mask, weight) in enumerate(supports):
            rhs.append(weight)
            pending = support_mask
            while pending:
                low_bit = pending & -pending
                detector = low_bit.bit_length() - 1
                pending ^= low_bit
                col_indices.append(local_index[detector])
                row_indices.append(row)
                data.append(1.0)

        a_ub = csr_matrix(
            (data, (row_indices, col_indices)),
            shape=(len(supports), len(component_detectors)),
            dtype=np.float64,
        )
        self.lp_calls += 1
        # Maximize sum(y) == minimize -sum(y).
        result = linprog(
            c=-np.ones(len(component_detectors), dtype=np.float64),
            A_ub=a_ub,
            b_ub=np.array(rhs, dtype=np.float64),
            bounds=[(0.0, None)] * len(component_detectors),
            method="highs",
        )
        if result.status == 0:
            return max(0.0, float(-result.fun))
        if result.status in {2, 3}:
            # Infeasible or unbounded: treat as an infinite penalty.
            return INF
        raise RuntimeError(f"linprog failed with status={result.status}: {result.message}")


def _require_stim() -> None:
    """Raise a clear error when Stim is missing but required."""
    if stim is None:
        raise RuntimeError(
            "This script requires stim for CLI operation. Install stim, or import the module and build models manually."
        )


def _likelihood_cost(probability: float) -> float:
    """Return -log(p / (1 - p)); inf for p <= 0 and 0 for p >= 1."""
    if probability <= 0.0:
        return math.inf
    if probability >= 1.0:
        return 0.0
    return -math.log(probability / (1.0 - probability))


def _iter_mask_bits(mask: int) -> Iterable[int]:
    """Yield the indices of the set bits of mask, lowest first."""
    while mask:
        low_bit = mask & -mask
        yield low_bit.bit_length() - 1
        mask ^= low_bit


def _detectors_from_mask(mask: int) -> list[int]:
    """Return the set-bit indices of mask as a list."""
    return list(_iter_mask_bits(mask))


def _mask_from_bool_row(row: np.ndarray) -> int:
    """Pack a boolean vector into an integer bitmask (bit i <- row[i])."""
    mask = 0
    for index in np.flatnonzero(row):
        mask |= 1 << int(index)
    return mask


def _future_detcost_by_detector(faults: tuple[Fault, ...], num_detectors: int) -> tuple[tuple[float, ...], ...]:
    """Build the plain heuristic table: cheapest per-detector cost over fault suffixes.

    Row i covers faults i..end; row len(faults) is all-inf (no faults remain).
    Each fault contributes likelihood_cost / len(detector_ids) to its detectors.
    """
    future_detcost: list[list[float]] = [[math.inf] * num_detectors for _ in range(len(faults) + 1)]
    next_row = future_detcost[-1]
    for fault_index in range(len(faults) - 1, -1, -1):
        row = next_row.copy()
        fault = faults[fault_index]
        det_count = len(fault.detector_ids)
        if det_count:
            ecost = fault.likelihood_cost / det_count
            for det_id in fault.detector_ids:
                if ecost < row[det_id]:
                    row[det_id] = ecost
        future_detcost[fault_index] = row
        next_row = row
    return tuple(tuple(row) for row in future_detcost)


def _build_decoder_model(circuit: stim.Circuit) -> DecoderModel:
    """Extract a DecoderModel from a Stim circuit's flattened detector error model.

    Only logical observable L0 is tracked; other observables are ignored.
    """
    _require_stim()
    dem = circuit.detector_error_model(decompose_errors=False).flattened()

    faults: list[Fault] = []
    all_possible_dets_mask = 0
    last_seen_index: dict[int, int] = {}
    detector_to_faults_lists: list[list[int]] = [[] for _ in range(circuit.num_detectors)]

    for inst in dem:
        if inst.type != "error":
            continue

        p = float(inst.args_copy()[0])
        det_mask = 0
        flip_l0 = 0
        for target in inst.targets_copy():
            if target.is_separator():
                continue
            if target.is_relative_detector_id():
                det_mask ^= 1 << target.val
            elif target.is_logical_observable_id() and target.val == 0:
                flip_l0 ^= 1

        detector_ids = tuple(_detectors_from_mask(det_mask))
        fault = Fault(
            q=1.0 - p,
            p=p,
            delta_scale=(-p if flip_l0 else p),
            det_mask=det_mask,
            detector_ids=detector_ids,
            likelihood_cost=_likelihood_cost(p),
        )
        faults.append(fault)
        all_possible_dets_mask |= det_mask
        fault_index = len(faults) - 1
        for det_id in detector_ids:
            last_seen_index[det_id] = fault_index
            detector_to_faults_lists[det_id].append(fault_index)

    # A detector "retires" after the last fault that can touch it.
    retiring_masks = [0] * len(faults)
    for det_id, index in last_seen_index.items():
        retiring_masks[index] |= 1 << det_id

    live_masks_after = [0] * (len(faults) + 1)
    active_mask = 0
    max_width = 0
    for i, fault in enumerate(faults):
        active_mask |= fault.det_mask
        max_width = max(max_width, active_mask.bit_count())
        active_mask &= ~retiring_masks[i]
        live_masks_after[i + 1] = active_mask

    frozen_faults = tuple(faults)
    return DecoderModel(
        faults=frozen_faults,
        retiring_masks=tuple(retiring_masks),
        live_masks_after=tuple(live_masks_after),
        plain_future_detcost=_future_detcost_by_detector(frozen_faults, circuit.num_detectors),
        detector_to_faults=tuple(tuple(v) for v in detector_to_faults_lists),
        all_possible_dets_mask=all_possible_dets_mask,
        max_width=max_width,
        num_detectors=circuit.num_detectors,
    )


def _detcost_penalty(mismatch_mask: int, future_detcost: tuple[float, ...]) -> float:
    """Plain heuristic: sum per-detector future costs; inf if any detector is dead."""
    total = 0.0
    pending = mismatch_mask

    while pending:
        low_bit = pending & -pending
        detector = low_bit.bit_length() - 1
        pending ^= low_bit

        best = future_detcost[detector]
        if best == math.inf:
            return math.inf
        total += best

    return total


def _as_bool_2d(data: np.ndarray, *, expected_cols: int, description: str) -> np.ndarray:
    """Validate that data is 2D with expected_cols columns and coerce to bool."""
    arr = np.asarray(data)
    if arr.ndim != 2:
        raise ValueError(f"Expected {description} to be a 2D array but got shape {arr.shape!r}.")
    if arr.shape[1] != expected_cols:
        raise ValueError(
            f"Expected {description} to have {expected_cols} columns but got {arr.shape[1]}."
        )
    if arr.dtype != np.bool_:
        arr = arr.astype(np.bool_, copy=False)
    return arr


def _sample_shot_arrays(
    circuit: stim.Circuit,
    *,
    shots: int,
    seed: int | None,
) -> tuple[np.ndarray, np.ndarray]:
    """Sample (detectors, observables) shot arrays from the circuit."""
    _require_stim()
    sampler = circuit.compile_detector_sampler(seed=seed)
    dets, obs = sampler.sample(shots=shots, separate_observables=True)
    return (
        _as_bool_2d(dets, expected_cols=circuit.num_detectors, description="sampled detector data"),
        _as_bool_2d(obs, expected_cols=circuit.num_observables, description="sampled observable data"),
    )


def _read_detector_shot_arrays(
    *,
    path: str,
    fmt: str,
    num_detectors: int,
    num_observables: int,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Read detector shots (and appended observables, if present) from a file.

    Falls back to reading a combined array and splitting columns when the
    installed Stim does not accept separate_observables= for reads.
    """
    _require_stim()
    common_kwargs = dict(
        path=path,
        format=fmt,
        bit_packed=False,
        num_measurements=0,
        num_detectors=num_detectors,
        num_observables=num_observables,
    )

    if num_observables:
        try:
            dets, obs = stim.read_shot_data_file(**common_kwargs, separate_observables=True)
            return (
                _as_bool_2d(dets, expected_cols=num_detectors, description="input detector data"),
                _as_bool_2d(obs, expected_cols=num_observables, description="appended observable data"),
            )
        except TypeError:
            flat = stim.read_shot_data_file(**common_kwargs)
            flat = _as_bool_2d(
                flat,
                expected_cols=num_detectors + num_observables,
                description="combined detector/observable input data",
            )
            return flat[:, :num_detectors], flat[:, num_detectors:]

    flat = stim.read_shot_data_file(**common_kwargs)
    return _as_bool_2d(flat, expected_cols=num_detectors, description="input detector data"), None


def _read_observable_shot_array(*, path: str, fmt: str, num_observables: int) -> np.ndarray:
    """Read an observables-only shot-data file."""
    _require_stim()
    obs = stim.read_shot_data_file(
        path=path,
        format=fmt,
        bit_packed=False,
        num_measurements=0,
        num_detectors=0,
        num_observables=num_observables,
    )
    return _as_bool_2d(obs, expected_cols=num_observables, description="observable input data")


def _apply_shot_range(
    dets: np.ndarray,
    obs: np.ndarray | None,
    *,
    shot_range_begin: int,
    shot_range_end: int,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Restrict the shot arrays to [begin, end); (0, 0) means all shots."""
    if not (shot_range_begin or shot_range_end):
        return dets, obs

    if shot_range_end < shot_range_begin:
        raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.")
    if shot_range_end > len(dets):
        raise ValueError(
            f"Shot range end {shot_range_end} is past the end of the shot data (size {len(dets)})."
        )

    dets = dets[shot_range_begin:shot_range_end]
    if obs is not None:
        obs = obs[shot_range_begin:shot_range_end]
    return dets, obs


def _shots_from_arrays(dets: np.ndarray, obs: np.ndarray | None) -> list[DecodingShot]:
    """Pack each row into a DecodingShot; only observable column 0 is used."""
    shots: list[DecodingShot] = []
    for shot_index in range(dets.shape[0]):
        actual_logical = None if obs is None else bool(obs[shot_index, 0])
        shots.append(
            DecodingShot(
                det_mask=_mask_from_bool_row(dets[shot_index]),
                actual_logical=actual_logical,
            )
        )
    return shots


def _resolve_stdin_path_if_needed(path: str, *, temp_dir: str, stem: str) -> str:
    """Spool stdin ('-') to a temp file so downstream readers get a real path."""
    if path != "-":
        return path
    temp_path = str(Path(temp_dir) / f"{stem}.bin")
    with open(temp_path, "wb") as f:
        f.write(sys.stdin.buffer.read())
    return temp_path


def _resolve_stdout_path_if_needed(path: str, *, temp_dir: str, stem: str) -> tuple[str, bool]:
    """Map stdout ('-') to a temp path; the bool says 'copy to stdout afterwards'."""
    if path != "-":
        return path, False
    return str(Path(temp_dir) / f"{stem}.bin"), True


def _copy_file_to_stdout(path: str) -> None:
    """Stream a binary file to stdout, flushing around the raw-buffer write."""
    sys.stdout.flush()
    with open(path, "rb") as f:
        shutil.copyfileobj(f, sys.stdout.buffer)
    sys.stdout.buffer.flush()


def _load_shots(
    circuit: stim.Circuit,
    args: argparse.Namespace,
    *,
    temp_dir: str,
) -> list[DecodingShot]:
    """Load shots per the CLI args: from --in (+ optional --obs-in) or by sampling."""
    if args.in_file:
        in_path = _resolve_stdin_path_if_needed(args.in_file, temp_dir=temp_dir, stem="shots_in")
        appended_obs_count = circuit.num_observables if args.in_includes_appended_observables else 0
        dets, obs = _read_detector_shot_arrays(
            path=in_path,
            fmt=args.in_format,
            num_detectors=circuit.num_detectors,
            num_observables=appended_obs_count,
        )

        if args.obs_in_file:
            obs_in_path = _resolve_stdin_path_if_needed(args.obs_in_file, temp_dir=temp_dir, stem="obs_in")
            obs = _read_observable_shot_array(
                path=obs_in_path,
                fmt=args.obs_in_format,
                num_observables=circuit.num_observables,
            )
            if len(obs) != len(dets):
                raise ValueError("Observable input ended before, or after, the detector shot data.")
    else:
        dets, obs = _sample_shot_arrays(circuit, shots=args.sample_num_shots, seed=args.sample_seed)

    dets, obs = _apply_shot_range(
        dets,
        obs,
        shot_range_begin=args.shot_range_begin,
        shot_range_end=args.shot_range_end,
    )
    return _shots_from_arrays(dets, obs)


def decode_beam_search_singleton_lp_ranked(
    model: DecoderModel,
    actual_dets_mask: int,
    L: int,
    *,
    heuristic: str,
    evaluator: OptimalSingletonLPEvaluator | None = None,
) -> BeamDecodeResult:
    """Decode one shot with a fault-ordered trellis sweep and a Top-L beam.

    Args:
        model: Preprocessed decoder model.
        actual_dets_mask: Bitmask of detectors that fired in this shot.
        L: Beam width; states are ranked by log(mass) minus the future penalty.
        heuristic: "opt_singleton_lp" (exact singleton LP) or "plain_detcost".
        evaluator: Optional shared LP evaluator (cache reuse across shots);
            only consulted when heuristic == "opt_singleton_lp".

    Returns:
        A BeamDecodeResult. predicted_logical is None if the shot is
        infeasible (detectors outside the model) or the surviving mass at the
        matched state is exactly zero.
    """
    start_time = time.perf_counter()

    if heuristic not in {"opt_singleton_lp", "plain_detcost"}:
        raise ValueError(f"Unsupported heuristic {heuristic!r}.")

    # A detector outside every fault's support can never be explained.
    if (actual_dets_mask & ~model.all_possible_dets_mask) != 0:
        return BeamDecodeResult(
            predicted_logical=None,
            certified=False,
            margin=0.0,
            discarded_mass=0.0,
            max_width=model.max_width,
            elapsed_seconds=time.perf_counter() - start_time,
        )

    if heuristic == "opt_singleton_lp":
        if evaluator is None:
            evaluator = OptimalSingletonLPEvaluator(model)
        context = evaluator.begin_shot()
        start_heuristic_calls = evaluator.heuristic_calls
        start_cache_hits = evaluator.cache_hits
        start_lp_calls = evaluator.lp_calls
        start_lp_seconds = evaluator.lp_seconds
    else:
        context = None
        start_heuristic_calls = 0
        start_cache_hits = 0
        start_lp_calls = 0
        start_lp_seconds = 0.0

    def shot_stats() -> dict:
        # Evaluator work attributable to *this* shot. Guarding on `context`
        # (not `evaluator`) avoids reporting another shot's cumulative totals
        # when a shared evaluator is passed with heuristic="plain_detcost".
        if context is None:
            return dict(heuristic_calls=0, cache_hits=0, lp_calls=0, lp_seconds=0.0)
        return dict(
            heuristic_calls=evaluator.heuristic_calls - start_heuristic_calls,
            cache_hits=evaluator.cache_hits - start_cache_hits,
            lp_calls=evaluator.lp_calls - start_lp_calls,
            lp_seconds=evaluator.lp_seconds - start_lp_seconds,
        )

    # Beam entries: (live-detector state mask, total mass, signed L0 margin).
    beam = [(0, 1.0, 1.0)]
    discarded_mass = 0.0

    for i, fault in enumerate(model.faults):
        collapsed_probs: dict[int, list[float]] = {}
        total_mass = 0.0
        retiring_mask = model.retiring_masks[i]

        if retiring_mask == 0:
            # No detector retires: branch each state on fault absent/present.
            for state, total, delta in beam:
                absent_total = total * fault.q
                absent_delta = delta * fault.q
                total_mass += absent_total
                entry = collapsed_probs.get(state)
                if entry is None:
                    collapsed_probs[state] = [absent_total, absent_delta]
                else:
                    entry[0] += absent_total
                    entry[1] += absent_delta

                toggled = state ^ fault.det_mask
                present_total = total * fault.p
                present_delta = delta * fault.delta_scale
                total_mass += present_total
                entry = collapsed_probs.get(toggled)
                if entry is None:
                    collapsed_probs[toggled] = [present_total, present_delta]
                else:
                    entry[0] += present_total
                    entry[1] += present_delta
        else:
            # Retiring detectors must match the observed syndrome bits exactly;
            # surviving states shrink to the still-live detectors.
            expected_bits = actual_dets_mask & retiring_mask
            keep_mask = ~retiring_mask
            for state, total, delta in beam:
                absent_total = total * fault.q
                absent_delta = delta * fault.q
                if (state & retiring_mask) == expected_bits:
                    shrunk = state & keep_mask
                    total_mass += absent_total
                    entry = collapsed_probs.get(shrunk)
                    if entry is None:
                        collapsed_probs[shrunk] = [absent_total, absent_delta]
                    else:
                        entry[0] += absent_total
                        entry[1] += absent_delta

                toggled = state ^ fault.det_mask
                present_total = total * fault.p
                present_delta = delta * fault.delta_scale
                if (toggled & retiring_mask) == expected_bits:
                    shrunk = toggled & keep_mask
                    total_mass += present_total
                    entry = collapsed_probs.get(shrunk)
                    if entry is None:
                        collapsed_probs[shrunk] = [present_total, present_delta]
                    else:
                        entry[0] += present_total
                        entry[1] += present_delta

        if total_mass == 0.0:
            # Every beam state was pruned away; the beam was too narrow.
            return BeamDecodeResult(
                predicted_logical=None,
                certified=False,
                margin=0.0,
                discarded_mass=discarded_mass,
                max_width=model.max_width,
                elapsed_seconds=time.perf_counter() - start_time,
                **shot_stats(),
            )

        live_target_mask = actual_dets_mask & model.live_masks_after[i + 1]
        if context is not None:
            evaluator.advance_past_fault(context, i)

        # Rank states by log(mass) minus the future-penalty heuristic.
        ranked_states: list[tuple[float, float, int, float]] = []
        for state, (total, delta) in collapsed_probs.items():
            mismatch_mask = state ^ live_target_mask
            if heuristic == "plain_detcost":
                penalty = _detcost_penalty(
                    mismatch_mask=mismatch_mask,
                    future_detcost=model.plain_future_detcost[i + 1],
                )
            else:
                assert evaluator is not None and context is not None
                penalty = evaluator.evaluate(context, mismatch_mask)
            if penalty == math.inf:
                rank_score = -math.inf
            else:
                rank_score = math.log(total) - penalty
            ranked_states.append((rank_score, total, state, delta))

        dropped_mass = 0.0
        if len(ranked_states) > L:
            ranked_states.sort(reverse=True)
            kept = ranked_states[:L]
            beam = [(state, total, delta) for _, total, state, delta in kept]
            kept_mass = sum(total for _, total, _, _ in kept)
            dropped_mass = total_mass - kept_mass
        else:
            beam = [(state, total, delta) for _, total, state, delta in ranked_states]

        # Renormalize so the beam's mass scale stays near 1 across many faults.
        inv_total_mass = 1.0 / total_mass
        discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass
        beam = [
            (state, total * inv_total_mass, delta * inv_total_mass)
            for state, total, delta in beam
        ]

    # Only state 0 fully explains the syndrome; its delta is the signed margin.
    _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0))
    margin = abs(final_delta)
    certified = margin > discarded_mass

    result = BeamDecodeResult(
        predicted_logical=None if final_delta == 0.0 else (final_delta < 0.0),
        certified=(False if final_delta == 0.0 else certified),
        margin=margin,
        discarded_mass=discarded_mass,
        max_width=model.max_width,
        elapsed_seconds=time.perf_counter() - start_time,
        **shot_stats(),
    )
    return result


def _print_run_header(
    *,
    circuit: stim.Circuit,
    args: argparse.Namespace,
    num_shots: int,
    log_stream,
    evaluator: OptimalSingletonLPEvaluator | None,
) -> None:
    """Print the run configuration banner to log_stream."""
    print(f"Running on circuit {args.circuit}", file=log_stream)
    print(f"Total Detectors: {circuit.num_detectors}", file=log_stream)
    print(f"Total Observables: {circuit.num_observables}", file=log_stream)
    print(f"Heuristic: {args.heuristic}", file=log_stream)
    if args.heuristic == "opt_singleton_lp":
        print(
            f"Singleton LP Cache: {'on' if not args.no_singleton_lp_cache else 'off'}",
            file=log_stream,
        )
        if evaluator is not None and evaluator.cache_max_entries > 0:
            print(f"Cache Max Entries: {evaluator.cache_max_entries}", file=log_stream)
        else:
            print("Cache Max Entries: unlimited", file=log_stream)
        print(
            f"Component Splitting: {'on' if not args.no_singleton_lp_component_splitting else 'off'}",
            file=log_stream,
        )
    if args.in_file:
        print(f"Shot Input: {args.in_file}", file=log_stream)
        print(f"Shot Input Format: {args.in_format}", file=log_stream)
        if args.in_includes_appended_observables:
            print("Observable Input: appended to --in", file=log_stream)
        elif args.obs_in_file:
            print(f"Observable Input: {args.obs_in_file}", file=log_stream)
            print(f"Observable Format: {args.obs_in_format}", file=log_stream)
        else:
            print("Observable Input: none", file=log_stream)
    else:
        print(f"Sample Seed: {args.sample_seed}", file=log_stream)
        print(f"Requested Shots: {args.sample_num_shots}", file=log_stream)
    if args.shot_range_begin or args.shot_range_end:
        print(
            f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})",
            file=log_stream,
        )
    print(f"Num Shots: {num_shots}", file=log_stream)


def run_experiment(args: argparse.Namespace) -> ExperimentSummary:
    """Decode every shot per the CLI args, log progress, and summarize.

    Raises ValueError for circuits without exactly one logical observable
    (only L0 is tracked). Predictions are optionally written via --out; log
    output goes to stderr when --out is stdout, else to stdout.
    """
    _require_stim()
    circuit = stim.Circuit.from_file(args.circuit)
    if circuit.num_observables != 1:
        raise ValueError(
            "This decoder currently supports exactly one logical observable, because it only tracks L0. "
            f"The circuit has {circuit.num_observables} observables."
        )

    model = _build_decoder_model(circuit)
    evaluator = None
    if args.heuristic == "opt_singleton_lp":
        evaluator = OptimalSingletonLPEvaluator(
            model,
            use_cache=not args.no_singleton_lp_cache,
            cache_max_entries=args.singleton_lp_cache_max_entries,
            split_components=not args.no_singleton_lp_component_splitting,
        )
    log_stream = sys.stderr if args.out_file == "-" else sys.stdout

    with tempfile.TemporaryDirectory() as temp_dir:
        shots = _load_shots(circuit, args, temp_dir=temp_dir)
        _print_run_header(
            circuit=circuit,
            args=args,
            num_shots=len(shots),
            log_stream=log_stream,
            evaluator=evaluator,
        )

        num_errors = 0
        num_low_confidence = 0
        num_certified = 0
        num_truth_shots = 0
        num_scored_shots = 0
        total_elapsed = 0.0
        total_triggered = 0
        max_width_seen = 0
        total_heuristic_calls = 0
        total_cache_hits = 0
        total_lp_calls = 0
        total_lp_seconds = 0.0
        predictions: list[bool | None] = []

        for shot_index, shot in enumerate(shots):
            if args.singleton_lp_clear_cache_between_shots and evaluator is not None:
                evaluator.clear_cache()

            result = decode_beam_search_singleton_lp_ranked(
                model,
                shot.det_mask,
                args.beam,
                heuristic=args.heuristic,
                evaluator=evaluator,
            )
            predictions.append(result.predicted_logical)

            # success is None when either truth or prediction is unavailable.
            success: bool | None
            if shot.actual_logical is None or result.predicted_logical is None:
                success = None
            else:
                success = result.predicted_logical == shot.actual_logical

            if result.predicted_logical is None:
                num_low_confidence += 1
            if shot.actual_logical is not None:
                num_truth_shots += 1
            if success is not None:
                num_scored_shots += 1
                if not success:
                    num_errors += 1
            if result.certified:
                num_certified += 1

            total_elapsed += result.elapsed_seconds
            total_heuristic_calls += result.heuristic_calls
            total_cache_hits += result.cache_hits
            total_lp_calls += result.lp_calls
            total_lp_seconds += result.lp_seconds
            triggered_dets = shot.det_mask.bit_count()
            total_triggered += triggered_dets
            max_width_seen = max(max_width_seen, result.max_width)

            shots_done = shot_index + 1
            error_rate_so_far = num_errors / num_scored_shots if num_scored_shots else 0.0
            progress = (
                f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} "
                f"low_conf_so_far={num_low_confidence} scored_shots_so_far={num_scored_shots} "
                f"error_rate_so_far={error_rate_so_far:.6f} elapsed_total_seconds={total_elapsed:.6f}"
            )
            if args.print_heuristic_stats:
                progress += (
                    f" heuristic_calls_so_far={total_heuristic_calls} cache_hits_so_far={total_cache_hits} "
                    f"lp_calls_so_far={total_lp_calls} lp_seconds_so_far={total_lp_seconds:.6f}"
                )
            print(progress, file=log_stream)

            if args.print_per_shot:
                line = (
                    f"shot={shot_index} triggered_detectors={triggered_dets} "
                    f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} "
                    f"success={success} certified={result.certified} "
                    f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} "
                    f"elapsed_seconds={result.elapsed_seconds:.6f}"
                )
                if args.print_heuristic_stats:
                    line += (
                        f" heuristic_calls={result.heuristic_calls} cache_hits={result.cache_hits} "
                        f"lp_calls={result.lp_calls} lp_seconds={result.lp_seconds:.6f}"
                    )
                print(line, file=log_stream)

        if args.out_file:
            output_path, copy_to_stdout = _resolve_stdout_path_if_needed(
                args.out_file,
                temp_dir=temp_dir,
                stem="predictions_out",
            )
            prediction_data = np.zeros((len(predictions), circuit.num_observables), dtype=np.bool_)
            for shot_index, predicted_logical in enumerate(predictions):
                prediction_data[shot_index, 0] = bool(predicted_logical) if predicted_logical is not None else False

            if args.out_format == "ptb64" and len(prediction_data) % 64 != 0:
                raise ValueError("The ptb64 format requires the number of shots to be a multiple of 64.")

            stim.write_shot_data_file(
                data=prediction_data,
                path=output_path,
                format=args.out_format,
                num_measurements=0,
                num_detectors=0,
                num_observables=circuit.num_observables,
            )
            if copy_to_stdout:
                _copy_file_to_stdout(output_path)
            if num_low_confidence:
                print(
                    f"warning: wrote {num_low_confidence} low-confidence predictions as L0=0 because Stim result "
                    "files can only store bits, not unknown values.",
                    file=log_stream,
                )

        print(f"Beam: {args.beam}", file=log_stream)
        print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}", file=log_stream)
        print(f"Max Width: {max_width_seen}", file=log_stream)
        print(f"Certified Shots: {num_certified}", file=log_stream)
        print(f"Low Confidence: {num_low_confidence}", file=log_stream)
        print(f"Truth-Labeled Shots: {num_truth_shots}", file=log_stream)
        print(f"Scored Shots: {num_scored_shots}", file=log_stream)
        if num_truth_shots:
            print(f"Logical Errors: {num_errors}", file=log_stream)
        else:
            print("Logical Errors: n/a", file=log_stream)
        print(f"Total Seconds: {total_elapsed:.6f}", file=log_stream)
        print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}", file=log_stream)
        if args.print_heuristic_stats:
            print(f"Heuristic Calls: {total_heuristic_calls}", file=log_stream)
            print(f"LP Cache Hits: {total_cache_hits}", file=log_stream)
            print(f"LP Solves: {total_lp_calls}", file=log_stream)
            print(f"LP Seconds: {total_lp_seconds:.6f}", file=log_stream)
            if evaluator is not None:
                print(f"Cache Entries: {len(evaluator.cache)}", file=log_stream)

        return ExperimentSummary(
            predictions=predictions,
            num_certified=num_certified,
            num_low_confidence=num_low_confidence,
            num_errors=num_errors,
            num_truth_shots=num_truth_shots,
            num_scored_shots=num_scored_shots,
            total_elapsed=total_elapsed,
            total_triggered=total_triggered,
            max_width_seen=max_width_seen,
            total_heuristic_calls=total_heuristic_calls,
            total_cache_hits=total_cache_hits,
            total_lp_calls=total_lp_calls,
            total_lp_seconds=total_lp_seconds,
        )


def _parse_args() -> argparse.Namespace:
    """Build the CLI, parse argv, and validate cross-argument constraints."""
    parser = argparse.ArgumentParser(
        description=(
            "Run trellis beam decoding ranked by mass minus an exact optimal singleton-LP future penalty, "
            "with Stim-compatible shot-data I/O options."
        ),
        allow_abbrev=False,
    )
    parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.")
    parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.")
    parser.add_argument(
        "--heuristic",
        choices=("opt_singleton_lp", "plain_detcost"),
        default="opt_singleton_lp",
        help=(
            "Future-penalty heuristic used for ranking beam states. "
            "'opt_singleton_lp' uses the exact optimal singleton LP; 'plain_detcost' recovers the original decoder."
        ),
    )
    parser.add_argument(
        "--sample-num-shots",
        type=int,
        default=None,
        help="Number of sampled shots. Defaults to 1 unless --in is provided.",
    )
    parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.")
    parser.add_argument(
        "--shot-range-begin",
        type=int,
        default=0,
        help=(
            "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. "
            "Otherwise only decode shots in [begin, end)."
        ),
    )
    parser.add_argument(
        "--shot-range-end",
        type=int,
        default=0,
        help=(
            "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. "
            "Otherwise only decode shots in [begin, end)."
        ),
    )
    parser.add_argument(
        "--in",
        dest="in_file",
        default="",
        help="File to read detection events from (use - for stdin).",
    )
    parser.add_argument(
        "--in-format",
        "--in_format",
        dest="in_format",
        choices=STIM_RESULT_FORMATS,
        default="01",
        help=f"Format of the file read by --in ({STIM_RESULT_FORMATS_HELP}).",
    )
    parser.add_argument(
        "--in-includes-appended-observables",
        "--in_includes_appended_observables",
        dest="in_includes_appended_observables",
        action="store_true",
        help="Assume the observable flips are appended to each shot in --in.",
    )
    parser.add_argument(
        "--obs-in",
        "--obs_in",
        dest="obs_in_file",
        default="",
        help="File to read observable flips from (use - for stdin).",
    )
    parser.add_argument(
        "--obs-in-format",
        "--obs_in_format",
        dest="obs_in_format",
        choices=STIM_RESULT_FORMATS,
        default="01",
        help=f"Format of the file read by --obs-in ({STIM_RESULT_FORMATS_HELP}).",
    )
    parser.add_argument(
        "--out",
        dest="out_file",
        default="",
        help="File to write predicted observable flips to (use - for stdout).",
    )
    parser.add_argument(
        "--out-format",
        "--out_format",
        dest="out_format",
        choices=STIM_RESULT_FORMATS,
        default="01",
        help=f"Format of the file written by --out ({STIM_RESULT_FORMATS_HELP}).",
    )
    parser.add_argument(
        "--no-singleton-lp-cache",
        action="store_true",
        help="Disable reuse of exact singleton-LP values across shots.",
    )
    parser.add_argument(
        "--singleton-lp-cache-max-entries",
        type=int,
        default=0,
        help="Optional LRU cap on cached exact singleton-LP states. 0 means unlimited.",
    )
    parser.add_argument(
        "--singleton-lp-clear-cache-between-shots",
        action="store_true",
        help="Clear the exact singleton-LP cache before every shot.",
    )
    parser.add_argument(
        "--no-singleton-lp-component-splitting",
        action="store_true",
        help="Disable decomposition of the singleton LP into disconnected detector components.",
    )
    parser.add_argument(
        "--print-heuristic-stats",
        action="store_true",
        help="Print exact singleton-LP and cache statistics during the run.",
    )
    parser.add_argument(
        "--print-per-shot",
        action="store_true",
        help="Print a detailed line per decoded shot.",
    )
    args = parser.parse_args()

    if args.sample_num_shots is None:
        args.sample_num_shots = 0 if args.in_file else 1

    if args.beam <= 0:
        raise ValueError("--beam must be positive.")
    if args.sample_num_shots < 0:
        raise ValueError("--sample-num-shots must be non-negative.")
    if args.sample_seed is not None and args.sample_seed < 0:
        raise ValueError("--sample-seed must be non-negative.")
    if args.shot_range_begin < 0 or args.shot_range_end < 0:
        raise ValueError("--shot-range-begin and --shot-range-end must be non-negative.")
    if args.shot_range_end < args.shot_range_begin:
        raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.")
    if args.in_includes_appended_observables and args.obs_in_file:
        raise ValueError(
            "Choose either --in-includes-appended-observables or --obs-in, not both."
        )
    if args.obs_in_file and not args.in_file:
        raise ValueError("Cannot load observable flips from --obs-in without also providing --in.")
    if args.in_file == "-" and args.obs_in_file == "-":
        raise ValueError("At most one of --in and --obs-in may read from stdin.")
    if args.singleton_lp_cache_max_entries < 0:
        raise ValueError("--singleton-lp-cache-max-entries must be non-negative.")
    if args.heuristic == "plain_detcost" and (
        args.no_singleton_lp_cache
        or args.singleton_lp_cache_max_entries
        or args.singleton_lp_clear_cache_between_shots
        or args.no_singleton_lp_component_splitting
    ):
        # Allowed but pointless; keep the CLI permissive.
        pass

    num_shot_sources = int(args.sample_num_shots > 0) + int(bool(args.in_file))
    if num_shot_sources != 1:
        raise ValueError("Requires exactly one source of shots: either --sample-num-shots > 0 or --in.")

    return args


if __name__ == "__main__":
    run_experiment(_parse_args())
+ """ + start_time = time.perf_counter() + + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + faults = [] + all_possible_dets_mask = 0 + + for inst in dem: + if inst.type != "error": + continue + + p = inst.args_copy()[0] + det_mask = 0 + flip_l0 = 0 + + for t in inst.targets_copy(): + if t.is_separator(): + continue + if t.is_relative_detector_id(): + det_mask ^= (1 << t.val) + elif t.is_logical_observable_id() and t.val == 0: + flip_l0 ^= 1 + + q = 1.0 - p + delta_scale = -p if flip_l0 else p + faults.append((q, p, delta_scale, det_mask)) + all_possible_dets_mask |= det_mask + + actual_dets_mask = 0 + for d in actual_dets: + actual_dets_mask ^= (1 << d) + + if (actual_dets_mask & ~all_possible_dets_mask) != 0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=0.0, + max_width=0, + elapsed_seconds=time.perf_counter() - start_time, + ) + + retiring_masks = [0] * len(faults) + last_seen_index = {} + + for idx, (_, _, _, det_mask) in enumerate(faults): + temp = det_mask + d_id = 0 + while temp > 0: + if temp & 1: + last_seen_index[d_id] = idx + temp >>= 1 + d_id += 1 + + for d_id, idx in last_seen_index.items(): + retiring_masks[idx] |= (1 << d_id) + + active_mask = 0 + max_width = 0 + for i, (_, _, _, det_mask) in enumerate(faults): + active_mask |= det_mask + max_width = max(max_width, active_mask.bit_count()) + active_mask &= ~retiring_masks[i] + + beam = [(0, 1.0, 1.0)] + discarded_mass = 0.0 + + for i, (q, p, delta_scale, det_mask) in enumerate(faults): + collapsed_probs: dict[int, list[float]] = {} + total_mass = 0.0 + retiring_mask = retiring_masks[i] + + if retiring_mask == 0: + for s, total, delta in beam: + absent_total = total * q + absent_delta = delta * q + total_mass += absent_total + entry = collapsed_probs.get(s) + if entry is None: + collapsed_probs[s] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + t = s ^ det_mask + 
present_total = total * p + present_delta = delta * delta_scale + total_mass += present_total + entry = collapsed_probs.get(t) + if entry is None: + collapsed_probs[t] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + else: + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + for s, total, delta in beam: + absent_total = total * q + absent_delta = delta * q + if (s & retiring_mask) == expected_bits: + shrunk_s = s & keep_mask + total_mass += absent_total + entry = collapsed_probs.get(shrunk_s) + if entry is None: + collapsed_probs[shrunk_s] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + t = s ^ det_mask + present_total = total * p + present_delta = delta * delta_scale + if (t & retiring_mask) == expected_bits: + shrunk_t = t & keep_mask + total_mass += present_total + entry = collapsed_probs.get(shrunk_t) + if entry is None: + collapsed_probs[shrunk_t] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + + ranked_states = [(total, state, delta) for state, (total, delta) in collapsed_probs.items()] + if total_mass == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + dropped_mass = 0.0 + if len(ranked_states) > L: + ranked_states.sort(reverse=True) + kept = ranked_states[:L] + beam = [(state, total, delta) for total, state, delta in kept] + kept_mass = sum(total for total, _, _ in kept) + dropped_mass = total_mass - kept_mass + else: + beam = [(state, total, delta) for total, state, delta in ranked_states] + + inv_total_mass = 1.0 / total_mass + discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass + beam = [ + (state, total * inv_total_mass, delta * inv_total_mass) + for state, total, delta in beam + ] + + _, _, final_delta = 
next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + margin = abs(final_delta) + certified = margin > discarded_mass + + if final_delta == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=margin, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + return BeamDecodeResult( + predicted_logical=final_delta < 0.0, + certified=certified, + margin=margin, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + +def run_experiment(circuit_fname: str, L: int, seed=None): + print(f"Running on circuit {circuit_fname}") + + circuit = stim.Circuit.from_file(circuit_fname) + + sampler = circuit.compile_detector_sampler(seed=seed) + syndromes, logicals = sampler.sample(shots=1, separate_observables=True) + + actual_dets = set(i for i, triggered in enumerate(syndromes[0]) if triggered) + actual_logical = logicals[0][0] + + result = decode_beam_search(circuit, actual_dets, L) + + print(f"Total Detectors: {circuit.num_detectors}") + print(f"Seed: {seed}") + print(f"Triggered Detectors: {len(actual_dets)}") + print(f"Width: {result.max_width}") + print(f"Predicted Logical: {result.predicted_logical}") + print(f"Actual Logical: {bool(actual_logical)}") + print(f"Certified: {result.certified}") + print(f"Margin: {result.margin:.6e}") + print(f"Discarded Mass: {result.discarded_mass:.6e}") + print(f"Elapsed Seconds: {result.elapsed_seconds:.6f}") + + if result.predicted_logical is None: + print("Result: DECODE FAILED (Tie or Beam too narrow)") + else: + print(f"Result: {'SUCCESS' if result.predicted_logical == actual_logical else 'LOGICAL ERROR'}") + print() + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run one-shot trellis beam decoding on a Stim circuit.") + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + 
parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") + parser.add_argument("--seed", type=int, default=None, help="Sampler seed.") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + run_experiment(args.circuit, L=args.beam, seed=args.seed) diff --git a/src/py/tesseract_test.py b/src/py/tesseract_test.py index bb62bea..337bf91 100644 --- a/src/py/tesseract_test.py +++ b/src/py/tesseract_test.py @@ -195,6 +195,20 @@ def test_create_tesseract_decoder(): assert decoder.cost_from_errors([1]) == pytest.approx(0.5108256237659907) +def test_tesseract_priority_queue_counters_track_search(): + config = tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL) + decoder = tesseract_decoder.tesseract.TesseractDecoder(config) + + decoder.decode_to_errors(np.array([True, False], dtype=bool)) + assert decoder.num_pq_pushed >= 1 + assert decoder.num_pq_popped >= 1 + assert decoder.num_pq_pushed >= decoder.num_pq_popped + + decoder.decode_to_errors(np.array([False, False], dtype=bool)) + assert decoder.num_pq_pushed == 1 + assert decoder.num_pq_popped == 1 + + def test_tesseract_compile_decoder(): shared_test_compile_decoder( tesseract_decoder.tesseract.TesseractConfig, diff --git a/src/tesseract.cc b/src/tesseract.cc index cc92d28..c6ce011 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -95,7 +95,9 @@ double TesseractDecoder::get_detcost( for (int ei : d2e[d]) { ec = error_costs[ei]; - if (ec.likelihood_cost * min_det_cost_det_count >= min_cost * errors[ei].symptom.detectors.size()) break; + if (ec.likelihood_cost * min_det_cost_det_count >= + min_cost * errors[ei].symptom.detectors.size()) + break; dct = detector_cost_tuples[ei]; if (!dct.error_blocked) { @@ -283,6 +285,8 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, size_t detector_order, size_t detector_beam) { predicted_errors_buffer.clear(); low_confidence_flag = false; + num_pq_pushed = 0; + num_pq_popped = 0; 
error_chain_arena.clear(); // Can technically be larger than pqlimit, but we need an initial guess on how many nodes we // will process from the queue. @@ -323,11 +327,12 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, std::vector next_detector_cost_tuples; pq.push({initial_cost, min_num_dets, 0, -1}); - size_t num_pq_pushed = 1; + num_pq_pushed = 1; while (!pq.empty()) { const Node node = pq.top(); pq.pop(); + ++num_pq_popped; if (node.num_dets > max_num_dets) continue; diff --git a/src/tesseract.h b/src/tesseract.h index 831e3a3..fc4173d 100644 --- a/src/tesseract.h +++ b/src/tesseract.h @@ -60,7 +60,7 @@ class Node { }; struct DetectorCostTuple { - uint32_t error_blocked; + uint8_t error_blocked; uint32_t detectors_count; }; @@ -97,6 +97,8 @@ struct TesseractDecoder { std::vector>& obs_predicted); bool low_confidence_flag = false; + size_t num_pq_pushed = 0; + size_t num_pq_popped = 0; std::vector predicted_errors_buffer; std::vector dem_error_to_error; std::vector error_to_dem_error; diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 3bdf477..5781881 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -468,6 +468,12 @@ void add_tesseract_module(py::module& root) { "The configuration used to create this decoder.") .def_readwrite("low_confidence_flag", &TesseractDecoder::low_confidence_flag, "A flag indicating if the decoder's prediction has low confidence.") + .def_readwrite( + "num_pq_pushed", &TesseractDecoder::num_pq_pushed, + "The number of items pushed to the priority queue during the most recent decode.") + .def_readwrite( + "num_pq_popped", &TesseractDecoder::num_pq_popped, + "The number of items popped from the priority queue during the most recent decode.") .def_readwrite( "predicted_errors_buffer", &TesseractDecoder::predicted_errors_buffer, "A buffer containing the predicted errors from the most recent decode operation.") diff --git a/src/tesseract.test.cc b/src/tesseract.test.cc index 
3bb34fb..ae62460 100644 --- a/src/tesseract.test.cc +++ b/src/tesseract.test.cc @@ -409,7 +409,7 @@ TEST(TesseractDetcostTest, ComparesRatiosNotRawCosts) { std::vector tuples(dec.errors.size()); // residual x = {D0, D1} - std::cout <<"dec.d2e.size() = "< +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +constexpr double INF_D = std::numeric_limits::infinity(); +constexpr double HEURISTIC_EPS = 1e-9; +constexpr double SIMPLEX_EPS = 1e-9; +constexpr double SEED_TIGHT_EPS = 1e-9; +constexpr double VIOLATION_EPS = 1e-9; +constexpr size_t VIOLATION_BATCH_SIZE = 4; + +struct UnionFind { + std::vector parent; + std::vector rank; + + explicit UnionFind(size_t n) : parent(n), rank(n, 0) { + std::iota(parent.begin(), parent.end(), 0); + } + + int find(int x) { + while (parent[x] != x) { + parent[x] = parent[parent[x]]; + x = parent[x]; + } + return x; + } + + void unite(int a, int b) { + a = find(a); + b = find(b); + if (a == b) return; + if (rank[a] < rank[b]) { + parent[a] = b; + } else if (rank[a] > rank[b]) { + parent[b] = a; + } else { + parent[b] = a; + rank[a]++; + } + } +}; + +template +std::ostream& operator<<(std::ostream& os, const std::vector& vec) { + os << "["; + bool is_first = true; + for (const auto& x : vec) { + if (!is_first) os << ", "; + is_first = false; + os << x; + } + os << "]"; + return os; +} + +template +struct IntVectorHash { + size_t operator()(const T& values) const { + return boost::hash_range(values.begin(), values.end()); + } +}; + +struct DenseSimplexResult { + bool success = false; + bool unbounded = false; + double objective = 0.0; + size_t pivots = 0; + std::vector solution; +}; + +template +double dot_on_support(const std::vector& values, const T& support) { + double total = 0.0; + for (int idx : support) total += values[(size_t)idx]; + return total; +} + +// Solves: +// maximize sum_i x_i +// subject to A x <= b +// x >= 0 +// where A is a 0/1 matrix given by 
row supports for a selected subset of rows. +DenseSimplexResult solve_dense_primal_packing_lp( + size_t num_vars, + const std::vector& constraints, + const std::vector& selected_rows, + const std::vector* entering_priorities = nullptr) { + DenseSimplexResult result; + result.solution.assign(num_vars, 0.0); + + const size_t num_rows = selected_rows.size(); + if (num_vars == 0) { + result.success = true; + return result; + } + if (num_rows == 0) { + result.unbounded = true; + return result; + } + + const size_t width = num_vars + num_rows + 1; + const size_t height = num_rows + 1; + std::vector tableau(height * width, 0.0); + std::vector basis(num_rows); + + for (size_t row = 0; row < num_rows; ++row) { + size_t orig_row = (size_t)selected_rows[row]; + for (int col : constraints[orig_row].local_detectors) { + tableau[row * width + (size_t)col] = 1.0; + } + tableau[row * width + num_vars + row] = 1.0; + tableau[row * width + width - 1] = constraints[orig_row].rhs; + basis[row] = num_vars + row; + if (constraints[orig_row].rhs < -SIMPLEX_EPS) { + throw std::runtime_error("Dense simplex received a negative RHS."); + } + } + for (size_t col = 0; col < num_vars; ++col) { + tableau[num_rows * width + col] = -1.0; + } + + auto pivot = [&](size_t pivot_row, size_t pivot_col) { + const double pivot_value = tableau[pivot_row * width + pivot_col]; + assert(std::abs(pivot_value) > SIMPLEX_EPS); + const double inv_pivot = 1.0 / pivot_value; + for (size_t col = 0; col < width; ++col) { + tableau[pivot_row * width + col] *= inv_pivot; + } + tableau[pivot_row * width + pivot_col] = 1.0; + + for (size_t row = 0; row < height; ++row) { + if (row == pivot_row) continue; + const double factor = tableau[row * width + pivot_col]; + if (std::abs(factor) <= SIMPLEX_EPS) { + tableau[row * width + pivot_col] = 0.0; + continue; + } + for (size_t col = 0; col < width; ++col) { + tableau[row * width + col] -= factor * tableau[pivot_row * width + col]; + } + tableau[row * width + pivot_col] = 
0.0; + } + basis[pivot_row] = pivot_col; + result.pivots++; + }; + + while (true) { + size_t entering_col = width; + double entering_priority = -INF_D; + for (size_t col = 0; col + 1 < width; ++col) { + if (tableau[num_rows * width + col] >= -SIMPLEX_EPS) continue; + const bool current_is_original = entering_col < num_vars; + const bool candidate_is_original = col < num_vars; + const double candidate_priority = candidate_is_original && entering_priorities != nullptr + ? (*entering_priorities)[col] + : -INF_D; + if (entering_col == width) { + entering_col = col; + entering_priority = candidate_priority; + continue; + } + if (candidate_is_original != current_is_original) { + if (candidate_is_original) { + entering_col = col; + entering_priority = candidate_priority; + } + continue; + } + if (candidate_priority > entering_priority + SIMPLEX_EPS || + (std::abs(candidate_priority - entering_priority) <= SIMPLEX_EPS && col < entering_col)) { + entering_col = col; + entering_priority = candidate_priority; + } + } + if (entering_col == width) { + break; + } + + size_t leaving_row = num_rows; + double best_ratio = INF_D; + for (size_t row = 0; row < num_rows; ++row) { + const double coeff = tableau[row * width + entering_col]; + if (coeff <= SIMPLEX_EPS) continue; + const double ratio = tableau[row * width + width - 1] / coeff; + if (ratio + SIMPLEX_EPS < best_ratio) { + best_ratio = ratio; + leaving_row = row; + } else if (std::abs(ratio - best_ratio) <= SIMPLEX_EPS && leaving_row != num_rows && + basis[row] < basis[leaving_row]) { + leaving_row = row; + } + } + + if (leaving_row == num_rows) { + result.unbounded = true; + return result; + } + pivot(leaving_row, entering_col); + } + + for (size_t row = 0; row < num_rows; ++row) { + if (basis[row] < num_vars) { + double value = tableau[row * width + width - 1]; + if (std::abs(value) <= SIMPLEX_EPS) value = 0.0; + result.solution[basis[row]] = value; + } + } + result.objective = tableau[num_rows * width + width - 1]; + if 
(std::abs(result.objective) <= SIMPLEX_EPS) result.objective = 0.0; + result.success = true; + return result; +} + +template +double lookup_detector_budget(const Solution& solution, int detector) { + auto it = std::lower_bound(solution.active_detectors.begin(), solution.active_detectors.end(), + detector); + if (it == solution.active_detectors.end() || *it != detector) return 0.0; + const size_t pos = (size_t)(it - solution.active_detectors.begin()); + return solution.detector_budgets[pos]; +} + +struct SingletonComponentSolveResult { + bool success = false; + bool unbounded = false; + double objective = 0.0; + size_t reduced_constraints = 0; + size_t simplex_solves = 0; + std::vector detector_budgets; +}; + +SingletonComponentSolveResult solve_singleton_component_lp( + size_t num_local_detectors, + const std::vector& constraints, + const std::vector& cheapest_constraint_for_local_detector, + const std::vector& seed_budgets) { + SingletonComponentSolveResult result; + result.detector_budgets.assign(num_local_detectors, 0.0); + + if (num_local_detectors == 0) { + result.success = true; + return result; + } + if (constraints.empty()) { + result.unbounded = true; + return result; + } + + const double seed_total = std::accumulate(seed_budgets.begin(), seed_budgets.end(), 0.0); + + std::vector selected(constraints.size(), 0); + std::vector selected_indices; + selected_indices.reserve(std::min(constraints.size(), num_local_detectors * 2 + 4)); + + auto add_constraint = [&](int idx) { + if (idx < 0) return; + if (!selected[(size_t)idx]) { + selected[(size_t)idx] = 1; + selected_indices.push_back(idx); + } + }; + + for (size_t row = 0; row < constraints.size(); ++row) { + const auto& constraint = constraints[row]; + const double slack = constraint.rhs - dot_on_support(seed_budgets, constraint.local_detectors); + if (slack <= SEED_TIGHT_EPS * (1.0 + constraint.rhs)) { + add_constraint((int)row); + } + } + + std::vector covered(num_local_detectors, 0); + for (int idx : 
selected_indices) { + for (int local : constraints[(size_t)idx].local_detectors) covered[(size_t)local] = 1; + } + for (size_t local = 0; local < num_local_detectors; ++local) { + if (!covered[local]) { + const int idx = cheapest_constraint_for_local_detector[local]; + if (idx < 0) { + throw std::runtime_error("Missing seed constraint for active detector."); + } + add_constraint(idx); + for (int touched : constraints[(size_t)idx].local_detectors) covered[(size_t)touched] = 1; + } + } + + if (selected_indices.empty()) { + throw std::runtime_error("Singleton LP seed set unexpectedly empty."); + } + + size_t rounds = 0; + while (true) { + if (++rounds > constraints.size() + 1) { + throw std::runtime_error("Constraint generation exceeded the number of unique constraints."); + } + + DenseSimplexResult simplex = solve_dense_primal_packing_lp(num_local_detectors, constraints, + selected_indices, &seed_budgets); + result.simplex_solves++; + if (simplex.unbounded) { + result.unbounded = true; + return result; + } + if (!simplex.success) { + return result; + } + if (simplex.objective + 1e-7 < seed_total) { + throw std::runtime_error("Reduced singleton LP optimum fell below the projected seed bound."); + } + + double max_violation = 0.0; + std::vector> top_violated; + top_violated.reserve(VIOLATION_BATCH_SIZE); + + for (size_t row = 0; row < constraints.size(); ++row) { + if (selected[row]) continue; + const auto& constraint = constraints[row]; + const double lhs = dot_on_support(simplex.solution, constraint.local_detectors); + const double violation = lhs - constraint.rhs; + if (violation > max_violation) { + max_violation = violation; + } + if (violation <= VIOLATION_EPS * (1.0 + constraint.rhs)) continue; + + top_violated.emplace_back(violation, (int)row); + std::sort(top_violated.begin(), top_violated.end(), + [](const auto& a, const auto& b) { return a.first > b.first; }); + if (top_violated.size() > VIOLATION_BATCH_SIZE) top_violated.pop_back(); + } + + if 
(max_violation <= VIOLATION_EPS) { + result.success = true; + result.objective = simplex.objective; + result.reduced_constraints = selected_indices.size(); + result.detector_budgets = std::move(simplex.solution); + return result; + } + + bool added_any = false; + for (const auto& [_, idx] : top_violated) { + if (!selected[(size_t)idx]) { + add_constraint(idx); + added_any = true; + } + } + if (!added_any) { + throw std::runtime_error("Constraint generation identified violations but added no rows."); + } + } +} + +std::string heuristic_source_to_string(FTLHeuristicSource source) { + switch (source) { + case FTLHeuristicSource::kPlain: + return "plain"; + case FTLHeuristicSource::kProjected: + return "projected"; + case FTLHeuristicSource::kExact: + return "exact"; + } + return "unknown"; +} + +std::string detector_choice_policy_to_string(FTLDetectorChoicePolicy policy) { + switch (policy) { + case FTLDetectorChoicePolicy::kOrder: + return "order"; + case FTLDetectorChoicePolicy::kFewestIncidentErrors: + return "fewest_incident_errors"; + case FTLDetectorChoicePolicy::kLargestBudget: + return "largest_budget"; + case FTLDetectorChoicePolicy::kLargestBudgetPerIncident: + return "largest_budget_per_incident"; + } + return "unknown"; +} + +std::string error_order_policy_to_string(FTLErrorOrderPolicy policy) { + switch (policy) { + case FTLErrorOrderPolicy::kStatic: + return "static"; + case FTLErrorOrderPolicy::kReducedCost: + return "reduced_cost"; + } + return "unknown"; +} + +} // namespace + +std::string TesseractFTLConfig::str() { + std::stringstream ss; + ss << "TesseractFTLConfig("; + ss << "dem=DetectorErrorModel_Object, "; + ss << "det_beam=" << det_beam << ", "; + ss << "no_revisit_dets=" << no_revisit_dets << ", "; + ss << "verbose=" << verbose << ", "; + ss << "merge_errors=" << merge_errors << ", "; + ss << "pqlimit=" << pqlimit << ", "; + ss << "det_orders=" << det_orders << ", "; + ss << "det_penalty=" << det_penalty << ", "; + ss << 
"create_visualization=" << create_visualization << ", "; + ss << "subset_detcost_size=" << subset_detcost_size << ", "; + ss << "ignore_blocked_errors_in_heuristic=" << ignore_blocked_errors_in_heuristic << ", "; + ss << "num_min_dets_to_consider=" << num_min_dets_to_consider << ", "; + ss << "detector_choice_policy=" << detector_choice_policy_to_string(detector_choice_policy) + << ", "; + ss << "error_order_policy=" << error_order_policy_to_string(error_order_policy) << ", "; + ss << "root_det_order_count=" << root_det_order_count << ", "; + ss << "root_det_order_depth=" << root_det_order_depth << ", "; + ss << "exact_child_refine_count=" << exact_child_refine_count; + ss << ")"; + return ss.str(); +} + +void TesseractFTLStats::clear() { + *this = TesseractFTLStats{}; +} + +void TesseractFTLStats::accumulate(const TesseractFTLStats& other) { + num_pq_pushed += other.num_pq_pushed; + num_nodes_popped += other.num_nodes_popped; + max_queue_size = std::max(max_queue_size, other.max_queue_size); + heuristic_calls += other.heuristic_calls; + plain_heuristic_calls += other.plain_heuristic_calls; + projection_heuristic_calls += other.projection_heuristic_calls; + exact_refinement_calls += other.exact_refinement_calls; + lp_calls += other.lp_calls; + lp_reinserts += other.lp_reinserts; + projected_nodes_generated += other.projected_nodes_generated; + projected_nodes_refined += other.projected_nodes_refined; + total_lp_refinement_gain += other.total_lp_refinement_gain; + max_lp_refinement_gain = std::max(max_lp_refinement_gain, other.max_lp_refinement_gain); + lp_total_seconds += other.lp_total_seconds; + chain_replay_total_seconds += other.chain_replay_total_seconds; + component_build_total_seconds += other.component_build_total_seconds; + component_candidate_total_seconds += other.component_candidate_total_seconds; + component_union_total_seconds += other.component_union_total_seconds; + component_dedup_total_seconds += other.component_dedup_total_seconds; + 
component_finalize_total_seconds += other.component_finalize_total_seconds; + simplex_total_seconds += other.simplex_total_seconds; + projection_total_seconds += other.projection_total_seconds; + component_build_calls += other.component_build_calls; + simplex_calls += other.simplex_calls; + projection_calls += other.projection_calls; + detector_choice_calls += other.detector_choice_calls; + error_ordering_calls += other.error_ordering_calls; + total_active_detectors_popped += other.total_active_detectors_popped; + total_root_order_candidates += other.total_root_order_candidates; + total_min_detector_candidates += other.total_min_detector_candidates; + total_min_detectors_selected += other.total_min_detectors_selected; + total_min_detector_available_errors += other.total_min_detector_available_errors; + total_min_detector_blocked_errors += other.total_min_detector_blocked_errors; + total_child_candidates_considered += other.total_child_candidates_considered; + total_children_generated += other.total_children_generated; + total_children_beam_pruned += other.total_children_beam_pruned; + total_children_infeasible += other.total_children_infeasible; + total_selected_min_detector_budget += other.total_selected_min_detector_budget; + exact_child_pre_refinements += other.exact_child_pre_refinements; +} + +bool TesseractFTLDecoder::FTLNode::operator>(const FTLNode& other) const { + return f_cost > other.f_cost || (f_cost == other.f_cost && num_dets < other.num_dets); +} + +size_t TesseractFTLDecoder::DynamicBitsetHash::operator()(const boost::dynamic_bitset<>& bs) const { + return boost::hash_value(bs); +} + +TesseractFTLDecoder::TesseractFTLDecoder(TesseractFTLConfig config_) : config(config_) { + if (config.subset_detcost_size > 1) { + throw std::invalid_argument( + "tesseract_ftl singleton mode supports only subset_detcost_size of 0 or 1"); + } + + if (config.subset_detcost_size == 0) { + TesseractConfig delegate_config; + delegate_config.dem = config.dem; + 
delegate_config.det_beam = config.det_beam; + delegate_config.beam_climbing = config.beam_climbing; + delegate_config.no_revisit_dets = config.no_revisit_dets; + delegate_config.verbose = config.verbose; + delegate_config.merge_errors = config.merge_errors; + delegate_config.pqlimit = config.pqlimit; + delegate_config.det_orders = config.det_orders; + delegate_config.det_penalty = config.det_penalty; + delegate_config.create_visualization = config.create_visualization; + plain_delegate = std::make_unique(delegate_config); + errors = plain_delegate->errors; + num_detectors = plain_delegate->num_detectors; + num_observables = plain_delegate->num_observables; + dem_error_to_error = plain_delegate->dem_error_to_error; + error_to_dem_error = plain_delegate->error_to_dem_error; + return; + } + + std::vector dem_error_map(config.dem.flattened().count_errors()); + std::iota(dem_error_map.begin(), dem_error_map.end(), 0); + + if (config.merge_errors) { + std::vector merge_map; + config.dem = common::merge_indistinguishable_errors(config.dem, merge_map); + common::chain_error_maps(dem_error_map, merge_map); + } + + std::vector nonzero_map; + config.dem = common::remove_zero_probability_errors(config.dem, nonzero_map); + common::chain_error_maps(dem_error_map, nonzero_map); + + dem_error_to_error = std::move(dem_error_map); + error_to_dem_error = common::invert_error_map(dem_error_to_error, config.dem.count_errors()); + + if (config.det_orders.empty()) { + config.det_orders.emplace_back(config.dem.count_detectors()); + std::iota(config.det_orders[0].begin(), config.det_orders[0].end(), 0); + } else { + for (const auto& order : config.det_orders) { + if (order.size() != config.dem.count_detectors()) { + throw std::invalid_argument( + "Each detector order list must have a size equal to the number of detectors."); + } + } + } + if (config.det_orders.empty()) { + throw std::runtime_error("Detector order list must not be empty."); + } + + errors = 
get_errors_from_dem(config.dem.flattened()); + num_detectors = config.dem.count_detectors(); + num_errors = config.dem.count_errors(); + num_observables = config.dem.count_observables(); + + initialize_structures(num_detectors); + + if (config.create_visualization) { + auto detectors = get_detector_coords(config.dem); + visualizer.add_detector_coords(detectors); + visualizer.add_errors(errors); + } +} + +TesseractFTLDecoder::~TesseractFTLDecoder() = default; + +void TesseractFTLDecoder::initialize_structures(size_t num_detectors_) { + d2e.resize(num_detectors_); + edets.resize(num_errors); + error_costs.resize(num_errors); + candidate_error_marks.assign(num_errors, 0); + candidate_error_mark_epoch = 1; + + for (size_t ei = 0; ei < num_errors; ++ei) { + edets[ei] = errors[ei].symptom.detectors; + for (int d : edets[ei]) { + d2e[(size_t)d].push_back((int)ei); + } + error_costs[ei] = {errors[ei].likelihood_cost, + errors[ei].likelihood_cost / errors[ei].symptom.detectors.size()}; + } + + for (size_t d = 0; d < num_detectors_; ++d) { + std::sort(d2e[d].begin(), d2e[d].end(), [this](int a, int b) { + return error_costs[(size_t)a].min_cost < error_costs[(size_t)b].min_cost; + }); + } +} + +void TesseractFTLDecoder::flip_detectors_and_block_errors( + size_t detector_order, int64_t error_chain_idx, boost::dynamic_bitset<>& detectors, + std::vector& blocked_flags) const { + (void)detector_order; + int64_t walker_idx = error_chain_idx; + while (walker_idx != -1) { + const auto& node = error_chain_arena[(size_t)walker_idx]; + const size_t ei = node.error_index; + const size_t min_detector = node.min_detector; + + for (int oei : d2e[min_detector]) { + blocked_flags[(size_t)oei] = 1; + if ((size_t)oei == ei) break; + } + for (int d : edets[ei]) detectors[(size_t)d] = !detectors[(size_t)d]; + walker_idx = node.parent_idx; + } +} + +void block_errors_from_chain(const std::vector& error_chain_arena, + const std::vector>& d2e, int64_t error_chain_idx, + std::vector& blocked_flags) 
{ + int64_t walker_idx = error_chain_idx; + while (walker_idx != -1) { + const auto& node = error_chain_arena[(size_t)walker_idx]; + const size_t ei = node.error_index; + const size_t min_detector = node.min_detector; + for (int oei : d2e[min_detector]) { + blocked_flags[(size_t)oei] = 1; + if ((size_t)oei == ei) break; + } + walker_idx = node.parent_idx; + } +} + +TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_components( + const boost::dynamic_bitset<>& detectors, const std::vector& blocked_flags) { + SingletonBuildResult result; + const auto candidate_start_time = std::chrono::high_resolution_clock::now(); + + std::vector active_detectors; + active_detectors.reserve(detectors.count()); + std::vector detector_to_active_pos(num_detectors, -1); + for (size_t detector = detectors.find_first(); detector != boost::dynamic_bitset<>::npos; + detector = detectors.find_next(detector)) { + detector_to_active_pos[detector] = (int)active_detectors.size(); + active_detectors.push_back((int)detector); + } + if (active_detectors.empty()) return result; + + if (candidate_error_mark_epoch == std::numeric_limits::max()) { + std::fill(candidate_error_marks.begin(), candidate_error_marks.end(), 0); + candidate_error_mark_epoch = 1; + } + const uint64_t mark_epoch = candidate_error_mark_epoch++; + std::vector candidate_errors; + for (int detector : active_detectors) { + for (int ei : d2e[(size_t)detector]) { + if (blocked_flags[(size_t)ei]) continue; + if (candidate_error_marks[(size_t)ei] == mark_epoch) continue; + candidate_error_marks[(size_t)ei] = mark_epoch; + candidate_errors.push_back(ei); + } + } + const auto candidate_stop_time = std::chrono::high_resolution_clock::now(); + stats.component_candidate_total_seconds += std::chrono::duration_cast( + candidate_stop_time - candidate_start_time) + .count() / + 1e6; + + const auto union_start_time = std::chrono::high_resolution_clock::now(); + UnionFind uf(active_detectors.size()); + std::vector 
has_available(active_detectors.size(), 0); + + for (int ei : candidate_errors) { + int first_active = -1; + for (int detector : edets[(size_t)ei]) { + const int active_pos = detector_to_active_pos[(size_t)detector]; + if (active_pos < 0) continue; + has_available[(size_t)active_pos] = 1; + if (first_active < 0) { + first_active = active_pos; + } else { + uf.unite(first_active, active_pos); + } + } + } + + for (size_t active_pos = 0; active_pos < active_detectors.size(); ++active_pos) { + if (!has_available[active_pos]) { + result.feasible = false; + return result; + } + } + + std::vector root_to_component_index(active_detectors.size(), -1); + std::vector active_pos_to_component(active_detectors.size(), -1); + std::vector active_pos_to_local(active_detectors.size(), -1); + result.components.reserve(active_detectors.size()); + for (int active_pos = 0; active_pos < (int)active_detectors.size(); ++active_pos) { + const int root = uf.find(active_pos); + int& component_index = root_to_component_index[(size_t)root]; + if (component_index < 0) { + component_index = (int)result.components.size(); + result.components.emplace_back(); + } + auto& component = result.components[(size_t)component_index]; + active_pos_to_component[(size_t)active_pos] = component_index; + active_pos_to_local[(size_t)active_pos] = (int)component.detectors.size(); + component.detectors.push_back(active_detectors[(size_t)active_pos]); + } + const auto union_stop_time = std::chrono::high_resolution_clock::now(); + stats.component_union_total_seconds += + std::chrono::duration_cast(union_stop_time - union_start_time) + .count() / + 1e6; + + const auto dedup_start_time = std::chrono::high_resolution_clock::now(); + std::vector, double, IntVectorHash>>> + min_rhs_by_pattern(result.components.size()); + std::vector local_hits; + local_hits.reserve(16); + + for (int ei : candidate_errors) { + int component_index = -1; + local_hits.clear(); + + for (int detector : edets[(size_t)ei]) { + const int active_pos 
= detector_to_active_pos[(size_t)detector]; + if (active_pos < 0) continue; + if (component_index < 0) { + component_index = active_pos_to_component[(size_t)active_pos]; + } else { + assert(component_index == active_pos_to_component[(size_t)active_pos]); + } + local_hits.push_back(active_pos_to_local[(size_t)active_pos]); + } + + if (component_index < 0) continue; + const double rhs = errors[(size_t)ei].likelihood_cost; + auto& rhs_map = min_rhs_by_pattern[(size_t)component_index]; + auto it = rhs_map.find(local_hits); + if (it == rhs_map.end() || rhs < it->second) { + rhs_map[local_hits] = rhs; + } + } + const auto dedup_stop_time = std::chrono::high_resolution_clock::now(); + stats.component_dedup_total_seconds += + std::chrono::duration_cast(dedup_stop_time - dedup_start_time) + .count() / + 1e6; + + const auto finalize_start_time = std::chrono::high_resolution_clock::now(); + for (size_t component_index = 0; component_index < result.components.size(); ++component_index) { + auto& component = result.components[component_index]; + const auto& rhs_map = min_rhs_by_pattern[component_index]; + component.constraints.reserve(rhs_map.size()); + for (const auto& [local_hits, rhs] : rhs_map) { + component.constraints.push_back({local_hits, rhs}); + } + std::sort(component.constraints.begin(), component.constraints.end(), + [](const auto& a, const auto& b) { + if (a.local_detectors.size() != b.local_detectors.size()) { + return a.local_detectors.size() < b.local_detectors.size(); + } + if (a.local_detectors != b.local_detectors) { + return a.local_detectors < b.local_detectors; + } + return a.rhs < b.rhs; + }); + + component.cheapest_constraint_for_local_detector.assign(component.detectors.size(), -1); + std::vector cheapest_rhs(component.detectors.size(), INF_D); + for (size_t constraint_index = 0; constraint_index < component.constraints.size(); + ++constraint_index) { + const auto& constraint = component.constraints[constraint_index]; + for (int local_detector : 
constraint.local_detectors) { + if (constraint.rhs < cheapest_rhs[(size_t)local_detector]) { + cheapest_rhs[(size_t)local_detector] = constraint.rhs; + component.cheapest_constraint_for_local_detector[(size_t)local_detector] = + (int)constraint_index; + } + } + } + for (size_t local = 0; local < component.detectors.size(); ++local) { + if (component.cheapest_constraint_for_local_detector[local] < 0) { + result.feasible = false; + result.components.clear(); + return result; + } + } + } + const auto finalize_stop_time = std::chrono::high_resolution_clock::now(); + stats.component_finalize_total_seconds += std::chrono::duration_cast( + finalize_stop_time - finalize_start_time) + .count() / + 1e6; + + return result; +} + +TesseractFTLDecoder::ExactSubsetSolution TesseractFTLDecoder::solve_exact_subset_lp( + const boost::dynamic_bitset<>& detectors, const std::vector& blocked_flags, + int64_t warm_solution_idx) { + stats.heuristic_calls++; + stats.exact_refinement_calls++; + const auto start_time = std::chrono::high_resolution_clock::now(); + + ExactSubsetSolution solution; + std::vector ignored_blocked_flags; + const std::vector* effective_blocked_flags = &blocked_flags; + if (config.ignore_blocked_errors_in_heuristic) { + ignored_blocked_flags.assign(num_errors, 0); + effective_blocked_flags = &ignored_blocked_flags; + } + const auto build_start_time = std::chrono::high_resolution_clock::now(); + const auto build = build_singleton_components(detectors, *effective_blocked_flags); + const auto build_stop_time = std::chrono::high_resolution_clock::now(); + stats.component_build_calls++; + stats.component_build_total_seconds += + std::chrono::duration_cast(build_stop_time - build_start_time) + .count() / + 1e6; + if (!build.feasible) { + solution.value = INF_D; + const auto stop_time = std::chrono::high_resolution_clock::now(); + stats.lp_total_seconds += + std::chrono::duration_cast(stop_time - start_time).count() / 1e6; + return solution; + } + if 
(build.components.empty()) { + solution.value = 0.0; + const auto stop_time = std::chrono::high_resolution_clock::now(); + stats.lp_total_seconds += + std::chrono::duration_cast(stop_time - start_time).count() / 1e6; + return solution; + } + + const ExactSubsetSolution* warm_solution = + warm_solution_idx >= 0 ? &exact_solution_arena[(size_t)warm_solution_idx] : nullptr; + solution.value = 0.0; + solution.num_components = build.components.size(); + std::vector> detector_budget_pairs; + detector_budget_pairs.reserve(detectors.count()); + size_t warm_pos = 0; + + for (const auto& component : build.components) { + std::vector seed_budgets(component.detectors.size(), 0.0); + if (warm_solution != nullptr) { + for (size_t local = 0; local < component.detectors.size(); ++local) { + int det = component.detectors[local]; + while (warm_pos < warm_solution->active_detectors.size() && + warm_solution->active_detectors[warm_pos] < det) { + ++warm_pos; + } + if (warm_pos < warm_solution->active_detectors.size() && + warm_solution->active_detectors[warm_pos] == det) { + seed_budgets[local] = warm_solution->detector_budgets[warm_pos]; + } + } + } + const auto simplex_start_time = std::chrono::high_resolution_clock::now(); + const auto component_result = solve_singleton_component_lp( + component.detectors.size(), component.constraints, + component.cheapest_constraint_for_local_detector, seed_budgets); + const auto simplex_stop_time = std::chrono::high_resolution_clock::now(); + stats.simplex_calls++; + stats.simplex_total_seconds += std::chrono::duration_cast( + simplex_stop_time - simplex_start_time) + .count() / + 1e6; + stats.lp_calls += component_result.simplex_solves; + + if (component_result.unbounded) { + throw std::runtime_error("Singleton custom LP became unbounded."); + } + if (!component_result.success) { + throw std::runtime_error("Singleton custom LP failed."); + } + + solution.value += component_result.objective; + solution.num_active_subsets += 
component.detectors.size(); + solution.num_variables += component.detectors.size(); + solution.num_constraints += component_result.reduced_constraints; + + for (size_t local = 0; local < component.detectors.size(); ++local) { + detector_budget_pairs.emplace_back(component.detectors[local], + component_result.detector_budgets[local]); + } + } + + if (detector_budget_pairs.size() > 1) { + std::sort(detector_budget_pairs.begin(), detector_budget_pairs.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + } + solution.active_detectors.reserve(detector_budget_pairs.size()); + solution.detector_budgets.reserve(detector_budget_pairs.size()); + for (const auto& [detector, budget] : detector_budget_pairs) { + solution.active_detectors.push_back(detector); + solution.detector_budgets.push_back(budget); + } + + const auto stop_time = std::chrono::high_resolution_clock::now(); + stats.lp_total_seconds += + std::chrono::duration_cast(stop_time - start_time).count() / 1e6; + return solution; +} + +double TesseractFTLDecoder::project_from_exact_solution(const ExactSubsetSolution& solution, + const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags) { + stats.heuristic_calls++; + stats.projection_heuristic_calls++; + const auto start_time = std::chrono::high_resolution_clock::now(); + stats.projection_calls++; + + double total = 0.0; + size_t budget_pos = 0; + const std::vector* effective_blocked_flags = &blocked_flags; + std::vector ignored_blocked_flags; + if (config.ignore_blocked_errors_in_heuristic) { + ignored_blocked_flags.assign(num_errors, 0); + effective_blocked_flags = &ignored_blocked_flags; + } + for (size_t detector = detectors.find_first(); detector != boost::dynamic_bitset<>::npos; + detector = detectors.find_next(detector)) { + bool has_available = false; + for (int ei : d2e[detector]) { + if (!(*effective_blocked_flags)[(size_t)ei]) { + has_available = true; + break; + } + } + if (!has_available) { + const auto stop_time 
= std::chrono::high_resolution_clock::now(); + stats.projection_total_seconds += + std::chrono::duration_cast(stop_time - start_time).count() / + 1e6; + return INF_D; + } + + while (budget_pos < solution.active_detectors.size() && + solution.active_detectors[budget_pos] < (int)detector) { + ++budget_pos; + } + if (budget_pos < solution.active_detectors.size() && + solution.active_detectors[budget_pos] == (int)detector) { + total += solution.detector_budgets[budget_pos]; + } + } + const auto stop_time = std::chrono::high_resolution_clock::now(); + stats.projection_total_seconds += + std::chrono::duration_cast(stop_time - start_time).count() / 1e6; + return total; +} + +std::vector TesseractFTLDecoder::select_min_detectors( + const boost::dynamic_bitset<>& detectors, const std::vector& blocked_flags, + size_t detector_order, size_t depth, const ExactSubsetSolution& exact_solution) { + stats.detector_choice_calls++; + stats.total_active_detectors_popped += detectors.count(); + + struct CandidateDetector { + size_t detector; + size_t order_rank; + size_t available_errors; + double budget; + }; + + const size_t order_count = depth < config.root_det_order_depth + ? 
std::min(config.root_det_order_count, config.det_orders.size()) + : 1; + std::vector seen(num_detectors, 0); + std::vector candidates; + candidates.reserve(detectors.count()); + + size_t discovery_rank = 0; + for (size_t order_offset = 0; order_offset < order_count; ++order_offset) { + size_t taken_from_order = 0; + const size_t order_index = (detector_order + order_offset) % config.det_orders.size(); + for (size_t offset = 0; offset < num_detectors; ++offset) { + const size_t detector = config.det_orders[order_index][offset]; + if (!detectors[detector]) continue; + if (!seen[detector]) { + seen[detector] = 1; + size_t available_errors = 0; + for (int ei : d2e[detector]) { + if (!blocked_flags[(size_t)ei]) { + available_errors++; + } + } + candidates.push_back({detector, discovery_rank++, available_errors, + lookup_detector_budget(exact_solution, (int)detector)}); + } + taken_from_order++; + if (config.detector_choice_policy == FTLDetectorChoicePolicy::kOrder && + taken_from_order >= config.num_min_dets_to_consider) { + break; + } + } + } + + stats.total_root_order_candidates += candidates.size(); + stats.total_min_detector_candidates += candidates.size(); + + if (config.detector_choice_policy != FTLDetectorChoicePolicy::kOrder) { + std::stable_sort(candidates.begin(), candidates.end(), [&](const auto& a, const auto& b) { + switch (config.detector_choice_policy) { + case FTLDetectorChoicePolicy::kOrder: + break; + case FTLDetectorChoicePolicy::kFewestIncidentErrors: + if (a.available_errors != b.available_errors) { + return a.available_errors < b.available_errors; + } + break; + case FTLDetectorChoicePolicy::kLargestBudget: + if (a.budget != b.budget) return a.budget > b.budget; + break; + case FTLDetectorChoicePolicy::kLargestBudgetPerIncident: { + const double a_score = + a.available_errors == 0 ? INF_D : a.budget / (double)a.available_errors; + const double b_score = + b.available_errors == 0 ? 
INF_D : b.budget / (double)b.available_errors; + if (a_score != b_score) return a_score > b_score; + break; + } + } + if (a.order_rank != b.order_rank) return a.order_rank < b.order_rank; + return a.detector < b.detector; + }); + } + + std::vector selected; + selected.reserve(std::min(config.num_min_dets_to_consider, candidates.size())); + for (const auto& candidate : candidates) { + selected.push_back(candidate.detector); + stats.total_min_detectors_selected++; + stats.total_min_detector_available_errors += candidate.available_errors; + stats.total_selected_min_detector_budget += candidate.budget; + if (selected.size() >= config.num_min_dets_to_consider) break; + } + return selected; +} + +std::vector TesseractFTLDecoder::order_candidate_errors( + size_t min_detector, const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags, const ExactSubsetSolution& exact_solution) { + stats.error_ordering_calls++; + + std::vector ordered_errors; + ordered_errors.reserve(d2e[min_detector].size()); + + if (config.error_order_policy == FTLErrorOrderPolicy::kStatic) { + for (int ei : d2e[min_detector]) { + if (blocked_flags[(size_t)ei]) { + stats.total_min_detector_blocked_errors++; + continue; + } + ordered_errors.push_back(ei); + } + return ordered_errors; + } + + struct CandidateError { + int error_index; + size_t order_rank; + double reduced_cost; + int net_det_delta; + }; + std::vector candidates; + candidates.reserve(d2e[min_detector].size()); + size_t order_rank = 0; + for (int ei : d2e[min_detector]) { + if (blocked_flags[(size_t)ei]) { + stats.total_min_detector_blocked_errors++; + continue; + } + double covered_budget = 0.0; + int net_det_delta = 0; + for (int detector : edets[(size_t)ei]) { + if (detectors[(size_t)detector]) { + covered_budget += lookup_detector_budget(exact_solution, detector); + net_det_delta--; + } else { + net_det_delta++; + } + } + candidates.push_back( + {ei, order_rank++, errors[(size_t)ei].likelihood_cost - covered_budget, 
net_det_delta}); + } + std::stable_sort(candidates.begin(), candidates.end(), [&](const auto& a, const auto& b) { + if (a.reduced_cost != b.reduced_cost) return a.reduced_cost < b.reduced_cost; + if (a.net_det_delta != b.net_det_delta) return a.net_det_delta < b.net_det_delta; + return a.order_rank < b.order_rank; + }); + for (const auto& candidate : candidates) ordered_errors.push_back(candidate.error_index); + return ordered_errors; +} + +void TesseractFTLDecoder::reset_decode_state() { + low_confidence_flag = false; + predicted_errors_buffer.clear(); + error_chain_arena.clear(); + detector_state_arena.clear(); + exact_solution_arena.clear(); + exact_solution_cache.clear(); + stats.clear(); +} + +void TesseractFTLDecoder::decode_to_errors(const std::vector& detections) { + if (config.verbose) { + std::cout << "shot"; + for (const uint64_t& d : detections) { + std::cout << " D" << d; + } + std::cout << std::endl; + } + if (plain_delegate) { + plain_delegate->decode_to_errors(detections); + predicted_errors_buffer = plain_delegate->predicted_errors_buffer; + low_confidence_flag = plain_delegate->low_confidence_flag; + stats.clear(); + return; + } + + std::vector best_errors; + double best_cost = std::numeric_limits::max(); + bool any_success = false; + TesseractFTLStats aggregate_stats; + stats.clear(); + + if (config.beam_climbing) { + int beam = 0; + int detector_order = 0; + for (int trial = 0; trial < std::max(config.det_beam + 1, int(config.det_orders.size())); + ++trial) { + decode_to_errors(detections, (size_t)detector_order, (size_t)beam); + aggregate_stats.accumulate(stats); + const double local_cost = cost_from_errors(predicted_errors_buffer); + if (!low_confidence_flag && local_cost < best_cost) { + best_errors = predicted_errors_buffer; + best_cost = local_cost; + any_success = true; + } + if (config.verbose) { + std::cout << "for detector_order " << detector_order << " beam " << beam + << " got low confidence " << low_confidence_flag << " and cost " << 
local_cost + << ". Best cost so far: " << best_cost << std::endl; + } + beam = (beam + 1) % (config.det_beam + 1); + detector_order = (detector_order + 1) % config.det_orders.size(); + } + } else { + for (size_t detector_order = 0; detector_order < config.det_orders.size(); ++detector_order) { + decode_to_errors(detections, detector_order, config.det_beam); + aggregate_stats.accumulate(stats); + const double local_cost = cost_from_errors(predicted_errors_buffer); + if (!low_confidence_flag && local_cost < best_cost) { + best_errors = predicted_errors_buffer; + best_cost = local_cost; + any_success = true; + } + if (config.verbose) { + std::cout << "for detector_order " << detector_order << " beam " << config.det_beam + << " got low confidence " << low_confidence_flag << " and cost " << local_cost + << ". Best cost so far: " << best_cost << std::endl; + } + } + } + predicted_errors_buffer = best_errors; + low_confidence_flag = !any_success; + stats = aggregate_stats; +} + +void TesseractFTLDecoder::decode_to_errors(const std::vector& detections, + size_t detector_order, size_t detector_beam) { + if (plain_delegate) { + plain_delegate->decode_to_errors(detections, detector_order, detector_beam); + predicted_errors_buffer = plain_delegate->predicted_errors_buffer; + low_confidence_flag = plain_delegate->low_confidence_flag; + return; + } + + reset_decode_state(); + if (config.pqlimit != std::numeric_limits::max()) { + const size_t reserve_size = std::min(config.pqlimit, 5000000); + error_chain_arena.reserve(reserve_size); + detector_state_arena.reserve(reserve_size + 1); + exact_solution_arena.reserve(reserve_size / 4 + 1); + } + + std::priority_queue, std::greater> pq; + std::vector, DynamicBitsetHash>> visited_detectors( + num_detectors + 1); + + boost::dynamic_bitset<> initial_detectors(num_detectors, false); + std::vector initial_blocked_flags(num_errors, 0); + for (size_t detector : detections) { + if (detector >= num_detectors) { + throw 
std::runtime_error("Symptom references detector >= num_detectors"); + } + initial_detectors[detector] = true; + } + + size_t min_num_dets = detections.size(); + size_t max_num_dets = + detector_beam > num_detectors - min_num_dets ? num_detectors : min_num_dets + detector_beam; + + FTLNode root; + root.g_cost = 0.0; + root.num_dets = min_num_dets; + root.depth = 0; + root.error_chain_idx = -1; + detector_state_arena.push_back(initial_detectors); + root.detector_state_idx = 0; + root.warm_solution_idx = -1; + root.exact_solution_idx = -1; + + ExactSubsetSolution root_exact = + solve_exact_subset_lp(initial_detectors, initial_blocked_flags, -1); + if (root_exact.value == INF_D) { + low_confidence_flag = true; + return; + } + exact_solution_arena.push_back(std::move(root_exact)); + root.exact_solution_idx = (int64_t)exact_solution_arena.size() - 1; + if (config.ignore_blocked_errors_in_heuristic) { + exact_solution_cache.emplace(initial_detectors, root.exact_solution_idx); + } + root.f_cost = exact_solution_arena.back().value; + root.h_cost = exact_solution_arena.back().value; + root.exact_refined = true; + root.heuristic_source = FTLHeuristicSource::kExact; + pq.push(root); + stats.num_pq_pushed = 1; + stats.max_queue_size = 1; + + while (!pq.empty()) { + stats.max_queue_size = std::max(stats.max_queue_size, pq.size()); + FTLNode node = pq.top(); + pq.pop(); + stats.num_nodes_popped++; + + if (node.num_dets > max_num_dets) continue; + + boost::dynamic_bitset<> detectors = detector_state_arena[(size_t)node.detector_state_idx]; + std::vector blocked_flags(num_errors, 0); + const auto chain_start_time = std::chrono::high_resolution_clock::now(); + block_errors_from_chain(error_chain_arena, d2e, node.error_chain_idx, blocked_flags); + const auto chain_stop_time = std::chrono::high_resolution_clock::now(); + stats.chain_replay_total_seconds += + std::chrono::duration_cast(chain_stop_time - chain_start_time) + .count() / + 1e6; + + if (config.verbose) { + const size_t 
projected_unrefined = + stats.projected_nodes_generated - stats.projected_nodes_refined; + std::cout.precision(13); + std::cout << "nodes_popped=" << stats.num_nodes_popped << " len(pq)=" << pq.size() + << " nodes_pushed=" << stats.num_pq_pushed << " lp_calls=" << stats.lp_calls + << " lp_reinserts=" << stats.lp_reinserts + << " proj_generated=" << stats.projected_nodes_generated + << " proj_refined=" << stats.projected_nodes_refined + << " proj_unrefined_so_far=" << projected_unrefined << " num_dets=" << node.num_dets + << " max_num_dets=" << max_num_dets << " f=" << node.f_cost << " g=" << node.g_cost + << " h=" << node.h_cost + << " h_source=" << heuristic_source_to_string(node.heuristic_source) + << " exact_refined=" << node.exact_refined << std::endl; + } + + if (node.num_dets == 0) { + predicted_errors_buffer.resize(node.depth); + int64_t walker_idx = node.error_chain_idx; + for (size_t i = 0; i < node.depth; ++i) { + predicted_errors_buffer[node.depth - 1 - i] = + error_to_dem_error[error_chain_arena[(size_t)walker_idx].error_index]; + walker_idx = error_chain_arena[(size_t)walker_idx].parent_idx; + } + if (config.verbose) { + std::cout << "Decoding complete. Cost: " << node.g_cost + << " num_pq_pushed = " << stats.num_pq_pushed << std::endl; + } + return; + } + + if (node.num_dets < min_num_dets) { + min_num_dets = node.num_dets; + const size_t next_max_num_dets = detector_beam > num_detectors - min_num_dets + ? 
num_detectors + : min_num_dets + detector_beam; + if (config.no_revisit_dets) { + for (size_t count = next_max_num_dets + 1; count <= max_num_dets; ++count) { + visited_detectors[count].clear(); + } + } + max_num_dets = std::min(max_num_dets, next_max_num_dets); + } + + if (!node.exact_refined) { + const double prev_h = node.h_cost; + const FTLHeuristicSource prev_source = node.heuristic_source; + bool used_cached_exact_solution = false; + int64_t cached_exact_solution_idx = -1; + if (config.ignore_blocked_errors_in_heuristic) { + auto it = exact_solution_cache.find(detectors); + if (it != exact_solution_cache.end()) { + used_cached_exact_solution = true; + cached_exact_solution_idx = it->second; + } + } + if (prev_source == FTLHeuristicSource::kProjected) stats.projected_nodes_refined++; + if (used_cached_exact_solution) { + node.exact_solution_idx = cached_exact_solution_idx; + node.h_cost = exact_solution_arena[(size_t)cached_exact_solution_idx].value; + const double delta = node.h_cost - prev_h; + if (node.h_cost + 1e-7 < prev_h) { + throw std::runtime_error("Cached singleton lower bound fell below stored lower bound."); + } + stats.total_lp_refinement_gain += delta; + stats.max_lp_refinement_gain = std::max(stats.max_lp_refinement_gain, delta); + node.f_cost = node.g_cost + node.h_cost; + node.exact_refined = true; + node.heuristic_source = FTLHeuristicSource::kExact; + if (delta > HEURISTIC_EPS) { + stats.lp_reinserts++; + pq.push(node); + stats.num_pq_pushed++; + if (stats.num_pq_pushed > config.pqlimit) { + low_confidence_flag = true; + return; + } + continue; + } + } else { + ExactSubsetSolution exact_solution = + solve_exact_subset_lp(detectors, blocked_flags, node.warm_solution_idx); + if (exact_solution.value == INF_D) { + if (config.verbose) { + std::cout << " lp_refine exact_h=INF discarded=true" << std::endl; + } + continue; + } + if (exact_solution.value + 1e-7 < prev_h) { + throw std::runtime_error("Exact singleton lower bound fell below stored 
lower bound."); + } + const double delta = exact_solution.value - prev_h; + stats.total_lp_refinement_gain += delta; + stats.max_lp_refinement_gain = std::max(stats.max_lp_refinement_gain, delta); + exact_solution_arena.push_back(std::move(exact_solution)); + node.exact_solution_idx = (int64_t)exact_solution_arena.size() - 1; + if (config.ignore_blocked_errors_in_heuristic) { + exact_solution_cache.emplace(detectors, node.exact_solution_idx); + } + node.h_cost = exact_solution_arena.back().value; + node.f_cost = node.g_cost + node.h_cost; + node.exact_refined = true; + node.heuristic_source = FTLHeuristicSource::kExact; + if (config.verbose) { + std::cout << " lp_refine approx_h=" << prev_h << " exact_h=" << node.h_cost + << " delta=" << delta << " vars=" << exact_solution_arena.back().num_variables + << " constraints=" << exact_solution_arena.back().num_constraints + << " reinserted=" << (delta > HEURISTIC_EPS) << std::endl; + } + if (delta > HEURISTIC_EPS) { + stats.lp_reinserts++; + pq.push(node); + stats.num_pq_pushed++; + if (stats.num_pq_pushed > config.pqlimit) { + low_confidence_flag = true; + return; + } + continue; + } + } + } + + if (config.no_revisit_dets && !visited_detectors[node.num_dets].insert(detectors).second) { + continue; + } + + const auto& exact_solution = exact_solution_arena[(size_t)node.exact_solution_idx]; + std::vector min_detectors = + select_min_detectors(detectors, blocked_flags, detector_order, node.depth, exact_solution); + if (min_detectors.empty()) { + throw std::runtime_error("Failed to select an active min detector for a non-terminal node."); + } + + size_t children_generated = 0; + size_t children_projected = 0; + size_t children_beam_pruned = 0; + size_t children_infeasible = 0; + size_t children_exactly_refined = 0; + + for (size_t min_detector : min_detectors) { + std::vector prefix_blocked = blocked_flags; + const std::vector ordered_errors = + order_candidate_errors(min_detector, detectors, blocked_flags, exact_solution); 
+ for (int ei : ordered_errors) { + prefix_blocked[(size_t)ei] = 1; + stats.total_child_candidates_considered++; + + boost::dynamic_bitset<> child_detectors = detectors; + size_t child_num_dets = node.num_dets; + for (int detector : edets[(size_t)ei]) { + if (detectors[(size_t)detector]) { + --child_num_dets; + } else { + ++child_num_dets; + } + child_detectors.flip((size_t)detector); + } + if (child_num_dets > max_num_dets) { + children_beam_pruned++; + stats.total_children_beam_pruned++; + continue; + } + + double child_h = + project_from_exact_solution(exact_solution, child_detectors, prefix_blocked); + stats.projected_nodes_generated++; + children_projected++; + if (child_h == INF_D) { + children_infeasible++; + stats.total_children_infeasible++; + continue; + } + + error_chain_arena.emplace_back(); + auto& chain_node = error_chain_arena.back(); + chain_node.error_index = (size_t)ei; + chain_node.min_detector = min_detector; + chain_node.parent_idx = node.error_chain_idx; + + FTLNode child; + child.g_cost = node.g_cost + errors[(size_t)ei].likelihood_cost; + child.h_cost = child_h; + child.f_cost = child.g_cost + child.h_cost; + child.num_dets = child_num_dets; + child.depth = node.depth + 1; + child.error_chain_idx = (int64_t)error_chain_arena.size() - 1; + detector_state_arena.push_back(std::move(child_detectors)); + child.detector_state_idx = (int64_t)detector_state_arena.size() - 1; + child.warm_solution_idx = node.exact_solution_idx; + child.exact_solution_idx = -1; + child.exact_refined = false; + child.heuristic_source = FTLHeuristicSource::kProjected; + + if (config.exact_child_refine_count > 0 && + children_exactly_refined < config.exact_child_refine_count) { + ExactSubsetSolution child_exact = + solve_exact_subset_lp(detector_state_arena[(size_t)child.detector_state_idx], + prefix_blocked, child.warm_solution_idx); + if (child_exact.value == INF_D) { + children_infeasible++; + stats.total_children_infeasible++; + continue; + } + 
exact_solution_arena.push_back(std::move(child_exact)); + child.exact_solution_idx = (int64_t)exact_solution_arena.size() - 1; + child.h_cost = exact_solution_arena.back().value; + child.f_cost = child.g_cost + child.h_cost; + child.exact_refined = true; + child.heuristic_source = FTLHeuristicSource::kExact; + children_exactly_refined++; + stats.exact_child_pre_refinements++; + } + + pq.push(child); + stats.num_pq_pushed++; + children_generated++; + stats.total_children_generated++; + if (stats.num_pq_pushed > config.pqlimit) { + low_confidence_flag = true; + return; + } + } + } + + if (config.verbose) { + const size_t projected_unrefined = + stats.projected_nodes_generated - stats.projected_nodes_refined; + std::cout << " expanded children_generated=" << children_generated + << " children_projected=" << children_projected + << " beam_pruned=" << children_beam_pruned << " infeasible=" << children_infeasible + << " lp_calls=" << stats.lp_calls + << " proj_unrefined_so_far=" << projected_unrefined << std::endl; + } + } + + if (config.verbose) { + std::cout << "Decoding failed to converge within beam limit." 
<< std::endl; + } + low_confidence_flag = true; +} + +double TesseractFTLDecoder::cost_from_errors(const std::vector& predicted_errors) const { + if (plain_delegate) return plain_delegate->cost_from_errors(predicted_errors); + double total_cost = 0.0; + for (size_t dem_error_index : predicted_errors) { + const size_t error_index = dem_error_to_error[dem_error_index]; + if (error_index == std::numeric_limits::max()) { + throw std::invalid_argument("error index does not map to a retained decoder error"); + } + total_cost += errors[error_index].likelihood_cost; + } + return total_cost; +} + +std::vector TesseractFTLDecoder::get_flipped_observables( + const std::vector& predicted_errors) const { + if (plain_delegate) return plain_delegate->get_flipped_observables(predicted_errors); + std::vector toggled(num_observables, 0); + for (size_t dem_error_index : predicted_errors) { + const size_t error_index = dem_error_to_error[dem_error_index]; + if (error_index == std::numeric_limits::max()) { + throw std::invalid_argument("error index does not map to a retained decoder error"); + } + for (int obs_index : errors[error_index].symptom.observables) { + toggled[(size_t)obs_index] ^= 1; + } + } + std::vector flipped_observables; + flipped_observables.reserve(num_observables); + for (size_t obs_index = 0; obs_index < num_observables; ++obs_index) { + if (toggled[obs_index]) flipped_observables.push_back((int)obs_index); + } + return flipped_observables; +} + +std::vector TesseractFTLDecoder::decode(const std::vector& detections) { + decode_to_errors(detections); + return get_flipped_observables(predicted_errors_buffer); +} + +void TesseractFTLDecoder::decode_shots(std::vector& shots, + std::vector>& obs_predicted) { + obs_predicted.resize(shots.size()); + for (size_t i = 0; i < shots.size(); ++i) { + obs_predicted[i] = decode(shots[i].hits); + } +} diff --git a/src/tesseract_ftl.h b/src/tesseract_ftl.h new file mode 100644 index 0000000..6df0373 --- /dev/null +++ 
b/src/tesseract_ftl.h @@ -0,0 +1,256 @@ + +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TESSERACT_FTL_DECODER_H +#define TESSERACT_FTL_DECODER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "stim.h" +#include "tesseract.h" +#include "utils.h" +#include "visualization.h" + +constexpr size_t DEFAULT_FTL_SUBSET_DETCOST_SIZE = 0; + +enum class FTLDetectorChoicePolicy : uint8_t { + kOrder = 0, + kFewestIncidentErrors = 1, + kLargestBudget = 2, + kLargestBudgetPerIncident = 3, +}; + +enum class FTLErrorOrderPolicy : uint8_t { + kStatic = 0, + kReducedCost = 1, +}; + +struct TesseractFTLConfig { + stim::DetectorErrorModel dem; + int det_beam = DEFAULT_DET_BEAM; + bool beam_climbing = false; + bool no_revisit_dets = true; + + bool verbose = false; + bool merge_errors = true; + size_t pqlimit = DEFAULT_PQLIMIT; + std::vector> det_orders; + double det_penalty = 0; + bool create_visualization = false; + bool ignore_blocked_errors_in_heuristic = false; + size_t num_min_dets_to_consider = 1; + FTLDetectorChoicePolicy detector_choice_policy = FTLDetectorChoicePolicy::kOrder; + FTLErrorOrderPolicy error_order_policy = FTLErrorOrderPolicy::kStatic; + size_t root_det_order_count = 1; + size_t root_det_order_depth = 0; + size_t exact_child_refine_count = 0; + + // 0 = delegate to the original Tesseract detcost heuristic. 
  // 1 = use the singleton fractional lower bound implemented in this file.
  size_t subset_detcost_size = DEFAULT_FTL_SUBSET_DETCOST_SIZE;

  std::string str();
};

// Which heuristic produced a node's h_cost: the plain Tesseract detcost
// delegate, a projection from a cached exact LP solution, or a fresh exact
// LP refinement.
enum class FTLHeuristicSource : uint8_t { kPlain = 0, kProjected = 1, kExact = 2 };

// Per-decoder instrumentation counters. Presumably accumulated across worker
// threads via accumulate() and reported by tesseract_ftl_main — confirm at
// the call sites.
struct TesseractFTLStats {
  // Priority-queue traffic during A* search.
  size_t num_pq_pushed = 0;
  size_t num_nodes_popped = 0;
  size_t max_queue_size = 0;

  // Heuristic evaluation counts, broken down by source.
  size_t heuristic_calls = 0;
  size_t plain_heuristic_calls = 0;
  size_t projection_heuristic_calls = 0;
  size_t exact_refinement_calls = 0;

  // LP solve activity and the h-cost gain obtained from exact refinement.
  size_t lp_calls = 0;
  size_t lp_reinserts = 0;
  size_t projected_nodes_generated = 0;
  size_t projected_nodes_refined = 0;
  double total_lp_refinement_gain = 0.0;
  double max_lp_refinement_gain = 0.0;

  // Wall-clock timing breakdown (seconds) of the LP heuristic pipeline.
  double lp_total_seconds = 0.0;
  double chain_replay_total_seconds = 0.0;
  double component_build_total_seconds = 0.0;
  double component_candidate_total_seconds = 0.0;
  double component_union_total_seconds = 0.0;
  double component_dedup_total_seconds = 0.0;
  double component_finalize_total_seconds = 0.0;
  double simplex_total_seconds = 0.0;
  double projection_total_seconds = 0.0;

  // Call counts for the pipeline stages timed above.
  size_t component_build_calls = 0;
  size_t simplex_calls = 0;
  size_t projection_calls = 0;
  size_t detector_choice_calls = 0;
  size_t error_ordering_calls = 0;

  // Node-expansion / branching statistics.
  size_t total_active_detectors_popped = 0;
  size_t total_root_order_candidates = 0;
  size_t total_min_detector_candidates = 0;
  size_t total_min_detectors_selected = 0;
  size_t total_min_detector_available_errors = 0;
  size_t total_min_detector_blocked_errors = 0;
  size_t total_child_candidates_considered = 0;
  size_t total_children_generated = 0;
  size_t total_children_beam_pruned = 0;
  size_t total_children_infeasible = 0;
  double total_selected_min_detector_budget = 0.0;
  size_t exact_child_pre_refinements = 0;

  // Resets every counter to zero.
  void clear();
  // Adds another stats object's counters into this one (max fields
  // presumably take the max — confirm in tesseract_ftl.cc).
  void accumulate(const TesseractFTLStats& other);
};

struct TesseractFTLDecoder {
  TesseractFTLConfig config;
  Visualizer visualizer;

  explicit 
TesseractFTLDecoder(TesseractFTLConfig config); + ~TesseractFTLDecoder(); + + // Clears the predicted_errors_buffer and fills it with the decoded errors for + // these detection events. + void decode_to_errors(const std::vector& detections); + + // Clears the predicted_errors_buffer and fills it with the decoded errors for + // these detection events, using a specified detector ordering index. + void decode_to_errors(const std::vector& detections, size_t detector_order, + size_t detector_beam); + + // Returns the bitwise XOR of the observables flipped by the errors in the given array, indexed by + // the original flattened DEM error indices. + std::vector get_flipped_observables(const std::vector& predicted_errors) const; + + // Returns the sum of likelihood costs of the errors in the given array, indexed by the original + // flattened DEM error indices. + double cost_from_errors(const std::vector& predicted_errors) const; + + std::vector decode(const std::vector& detections); + void decode_shots(std::vector& shots, + std::vector>& obs_predicted); + + bool low_confidence_flag = false; + std::vector predicted_errors_buffer; + std::vector dem_error_to_error; + std::vector error_to_dem_error; + std::vector errors; + size_t num_observables = 0; + size_t num_detectors = 0; + TesseractFTLStats stats; + + struct SingletonPatternConstraint { + std::vector local_detectors; + double rhs = 0.0; + }; + + private: + struct ErrorCost { + double likelihood_cost = 0; + double min_cost = 0; + }; + + struct FTLNode { + double f_cost = 0.0; + double g_cost = 0.0; + double h_cost = 0.0; + size_t num_dets = 0; + size_t depth = 0; + int64_t error_chain_idx = -1; + int64_t detector_state_idx = -1; + int64_t warm_solution_idx = -1; + int64_t exact_solution_idx = -1; + bool exact_refined = false; + FTLHeuristicSource heuristic_source = FTLHeuristicSource::kPlain; + + bool operator>(const FTLNode& other) const; + }; + struct SingletonLPComponent { + std::vector detectors; + std::vector 
constraints; + std::vector cheapest_constraint_for_local_detector; + }; + + struct ExactSubsetSolution { + double value = 0.0; + size_t num_active_subsets = 0; + size_t num_components = 0; + size_t num_variables = 0; + size_t num_constraints = 0; + std::vector active_detectors; + std::vector detector_budgets; + }; + + struct SingletonBuildResult { + bool feasible = true; + std::vector components; + }; + + struct DynamicBitsetHash { + size_t operator()(const boost::dynamic_bitset<>& bs) const; + }; + + std::vector> d2e; + std::vector> edets; + size_t num_errors = 0; + std::vector error_costs; + std::vector error_chain_arena; + std::vector> detector_state_arena; + std::vector exact_solution_arena; + std::unordered_map, int64_t, DynamicBitsetHash> exact_solution_cache; + mutable std::vector candidate_error_marks; + mutable uint64_t candidate_error_mark_epoch = 1; + + // If subset_detcost_size == 0, delegate to the original Tesseract decoder. + std::unique_ptr plain_delegate; + + void initialize_structures(size_t num_detectors); + + void flip_detectors_and_block_errors(size_t detector_order, int64_t error_chain_idx, + boost::dynamic_bitset<>& detectors, + std::vector& blocked_flags) const; + + SingletonBuildResult build_singleton_components(const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags); + + ExactSubsetSolution solve_exact_subset_lp(const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags, + int64_t warm_solution_idx); + + double project_from_exact_solution(const ExactSubsetSolution& solution, + const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags); + + std::vector select_min_detectors(const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags, + size_t detector_order, size_t depth, + const ExactSubsetSolution& exact_solution); + + std::vector order_candidate_errors(size_t min_detector, + const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags, + const 
ExactSubsetSolution& exact_solution); + + void reset_decode_state(); +}; + +#endif // TESSERACT_FTL_DECODER_H diff --git a/src/tesseract_ftl_main.cc b/src/tesseract_ftl_main.cc new file mode 100644 index 0000000..92758b0 --- /dev/null +++ b/src/tesseract_ftl_main.cc @@ -0,0 +1,609 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "stim.h" +#include "tesseract_ftl.h" +#include "utils.h" + +namespace { + +FTLDetectorChoicePolicy parse_detector_choice_policy(const std::string& value) { + if (value == "order") return FTLDetectorChoicePolicy::kOrder; + if (value == "fewest_incident_errors") return FTLDetectorChoicePolicy::kFewestIncidentErrors; + if (value == "largest_budget") return FTLDetectorChoicePolicy::kLargestBudget; + if (value == "largest_budget_per_incident") { + return FTLDetectorChoicePolicy::kLargestBudgetPerIncident; + } + throw std::invalid_argument("Unknown detector choice policy: " + value); +} + +FTLErrorOrderPolicy parse_error_order_policy(const std::string& value) { + if (value == "static") return FTLErrorOrderPolicy::kStatic; + if (value == "reduced_cost") return FTLErrorOrderPolicy::kReducedCost; + throw std::invalid_argument("Unknown error order policy: " + value); +} + +} // namespace + +struct Args { + std::string circuit_path; + std::string dem_path; + bool no_merge_errors = false; + 
+ uint64_t det_order_seed; + size_t num_det_orders = 10; + bool det_order_bfs = false; + bool det_order_index = false; + bool det_order_coordinate = false; + + size_t sample_num_shots = 0; + size_t max_errors = SIZE_MAX; + uint64_t sample_seed; + + size_t shot_range_begin = 0; + size_t shot_range_end = 0; + + std::string in_fname = ""; + std::string in_format = ""; + std::string obs_in_fname = ""; + std::string obs_in_format = ""; + bool append_observables = false; + std::string out_fname = ""; + std::string out_format = ""; + + std::string dem_out_fname = ""; + std::string stats_out_fname = ""; + + size_t num_threads = 1; + + size_t det_beam; + double det_penalty = 0; + bool beam_climbing = false; + bool no_revisit_dets = false; + size_t pqlimit; + + size_t subset_detcost_size = 0; + bool ignore_blocked_errors_in_heuristic = false; + size_t num_min_dets_to_consider = 1; + std::string detector_choice_policy = "order"; + std::string error_order_policy = "static"; + size_t root_det_order_count = 1; + size_t root_det_order_depth = 0; + size_t exact_child_refine_count = 0; + + bool verbose = false; + bool print_stats = false; + + bool has_observables() { + return append_observables || !obs_in_fname.empty() || (sample_num_shots > 0); + } + + void validate() { + if (circuit_path.empty() && dem_path.empty()) { + throw std::invalid_argument("Must provide at least one of --circuit or --dem"); + } + int det_order_flags = int(det_order_bfs) + int(det_order_index) + int(det_order_coordinate); + if (det_order_flags > 1) { + throw std::invalid_argument( + "Only one of --det-order-bfs, --det-order-index, or --det-order-coordinate may be set."); + } + int num_data_sources = int(sample_num_shots > 0) + int(!in_fname.empty()); + if (num_data_sources != 1) { + throw std::invalid_argument("Requires exactly 1 source of shots."); + } + if (!in_fname.empty() && in_format.empty()) { + throw std::invalid_argument("If --in is provided, must also specify --in-format."); + } + if 
(!out_fname.empty() && out_format.empty()) { + throw std::invalid_argument("If --out is provided, must also specify --out-format."); + } + if (!in_format.empty() && !stim::format_name_to_enum_map().contains(in_format)) { + throw std::invalid_argument("Invalid format: " + in_format); + } + if (!obs_in_format.empty() && !stim::format_name_to_enum_map().contains(obs_in_format)) { + throw std::invalid_argument("Invalid format: " + obs_in_format); + } + if (!out_format.empty() && !stim::format_name_to_enum_map().contains(out_format)) { + throw std::invalid_argument("Invalid format: " + out_format); + } + if (!obs_in_fname.empty() && in_fname.empty()) { + throw std::invalid_argument( + "Cannot load observable flips without a corresponding detection event data file."); + } + if (num_threads == 0) { + throw std::invalid_argument("--threads must be at least 1."); + } + if (shot_range_begin || shot_range_end) { + if (shot_range_end < shot_range_begin) { + throw std::invalid_argument("Provided shot range must have end >= begin."); + } + } + if (sample_num_shots > 0 && circuit_path.empty()) { + throw std::invalid_argument("Cannot sample shots without a circuit."); + } + if (beam_climbing && det_beam == INF_DET_BEAM) { + throw std::invalid_argument("Beam climbing requires a finite beam"); + } + if (subset_detcost_size > 1) { + throw std::invalid_argument("This prototype currently supports --subset-detcost-size <= 1"); + } + if (num_min_dets_to_consider == 0) { + throw std::invalid_argument("--num-min-dets-to-consider must be at least 1"); + } + if (root_det_order_count == 0) { + throw std::invalid_argument("--root-det-order-count must be at least 1"); + } + parse_detector_choice_policy(detector_choice_policy); + parse_error_order_policy(error_order_policy); + } + + void extract(TesseractFTLConfig& config, std::vector& shots, + std::unique_ptr& writer) { + stim::Circuit circuit; + if (!circuit_path.empty()) { + FILE* file = fopen(circuit_path.c_str(), "r"); + if (!file) throw 
std::invalid_argument("Could not open the file: " + circuit_path); + circuit = stim::Circuit::from_file(file); + fclose(file); + } + + if (!dem_path.empty()) { + FILE* file = fopen(dem_path.c_str(), "r"); + if (!file) throw std::invalid_argument("Could not open the file: " + dem_path); + config.dem = stim::DetectorErrorModel::from_file(file); + fclose(file); + } else { + assert(!circuit_path.empty()); + config.dem = stim::ErrorAnalyzer::circuit_to_detector_error_model( + circuit, /*decompose_errors=*/false, /*fold_loops=*/true, + /*allow_gauge_detectors=*/true, + /*approximate_disjoint_errors_threshold=*/1, + /*ignore_decomposition_failures=*/false, + /*block_decomposition_from_introducing_remnant_edges=*/false); + } + + config.merge_errors = !no_merge_errors; + config.subset_detcost_size = subset_detcost_size; + config.ignore_blocked_errors_in_heuristic = ignore_blocked_errors_in_heuristic; + config.num_min_dets_to_consider = num_min_dets_to_consider; + config.detector_choice_policy = parse_detector_choice_policy(detector_choice_policy); + config.error_order_policy = parse_error_order_policy(error_order_policy); + config.root_det_order_count = root_det_order_count; + config.root_det_order_depth = root_det_order_depth; + config.exact_child_refine_count = exact_child_refine_count; + + { + DetOrder order = DetOrder::DetBFS; + if (det_order_index) { + order = DetOrder::DetIndex; + } else if (det_order_coordinate) { + order = DetOrder::DetCoordinate; + } + config.det_orders = build_det_orders(config.dem, num_det_orders, order, det_order_seed); + } + + if (sample_num_shots > 0) { + assert(!circuit_path.empty()); + std::mt19937_64 rng(sample_seed); + size_t num_detectors = circuit.count_detectors(); + const auto [dets, obs] = + stim::sample_batch_detection_events<64>(circuit, sample_num_shots, rng); + stim::simd_bit_table<64> obs_T = obs.transposed(); + shots.resize(sample_num_shots); + for (size_t k = 0; k < sample_num_shots; k++) { + shots[k].obs_mask = obs_T[k]; + for 
(size_t d = 0; d < num_detectors; d++) { + if (dets[d][k]) shots[k].hits.push_back(d); + } + } + } + + if (!in_fname.empty()) { + FILE* shots_file = fopen(in_fname.c_str(), "r"); + if (!shots_file) throw std::invalid_argument("Could not open the file: " + in_fname); + stim::FileFormatData shots_in_format = stim::format_name_to_enum_map().at(in_format); + auto reader = stim::MeasureRecordReader::make( + shots_file, shots_in_format.id, 0, config.dem.count_detectors(), + append_observables * config.dem.count_observables()); + stim::SparseShot sparse_shot; + sparse_shot.clear(); + while (reader->start_and_read_entire_record(sparse_shot)) { + shots.push_back(sparse_shot); + sparse_shot.clear(); + } + fclose(shots_file); + } + + if (!obs_in_fname.empty()) { + FILE* obs_file = fopen(obs_in_fname.c_str(), "r"); + if (!obs_file) throw std::invalid_argument("Could not open the file: " + obs_in_fname); + stim::FileFormatData obs_format = stim::format_name_to_enum_map().at(obs_in_format); + auto obs_reader = stim::MeasureRecordReader::make( + obs_file, obs_format.id, 0, 0, config.dem.count_observables()); + stim::SparseShot sparse_shot; + sparse_shot.clear(); + size_t num_obs_shots = 0; + while (obs_reader->start_and_read_entire_record(sparse_shot)) { + if (num_obs_shots >= shots.size()) { + throw std::invalid_argument("Shot data ended before obs data."); + } + shots[num_obs_shots].obs_mask = sparse_shot.obs_mask; + sparse_shot.clear(); + ++num_obs_shots; + } + if (num_obs_shots != shots.size()) { + throw std::invalid_argument("Obs data ended before shot data ended."); + } + fclose(obs_file); + } + + if (shot_range_begin || shot_range_end) { + if (shot_range_end > shots.size()) { + throw std::invalid_argument("Shot range end is past end of shots array."); + } + std::vector shots_in_range(shots.begin() + shot_range_begin, + shots.begin() + shot_range_end); + std::swap(shots_in_range, shots); + } + + if (!out_fname.empty()) { + stim::FileFormatData predictions_out_format = 
stim::format_name_to_enum_map().at(out_format); + FILE* predictions_file = stdout; + if (out_fname != "-") predictions_file = fopen(out_fname.c_str(), "w"); + writer = stim::MeasureRecordWriter::make(predictions_file, predictions_out_format.id); + writer->begin_result_type('L'); + } + + config.det_beam = det_beam; + config.det_penalty = det_penalty; + config.beam_climbing = beam_climbing; + config.no_revisit_dets = no_revisit_dets; + config.pqlimit = pqlimit; + config.verbose = verbose; + } +}; + +int main(int argc, char* argv[]) { + std::cout.precision(16); + argparse::ArgumentParser program("tesseract_ftl"); + Args args; + + program.add_argument("--circuit").help("Stim circuit file path").store_into(args.circuit_path); + program.add_argument("--dem").help("Stim dem file path").store_into(args.dem_path); + program.add_argument("--no-merge-errors") + .help("If provided, will not merge identical error mechanisms.") + .store_into(args.no_merge_errors); + program.add_argument("--subset-detcost-size") + .help("0 = plain detcost delegate, 1 = singleton fractional lower bound") + .default_value(size_t(0)) + .store_into(args.subset_detcost_size); + program.add_argument("--ignore-blocked-errors-in-heuristic") + .help("Experimental: ignore precedence-blocked errors when computing the FTL LP heuristic") + .flag() + .store_into(args.ignore_blocked_errors_in_heuristic); + program.add_argument("--num-min-dets-to-consider") + .help( + "Experimental: when expanding a node, branch on the first N active detectors in the " + "selected detector order.") + .default_value(size_t(1)) + .store_into(args.num_min_dets_to_consider); + program.add_argument("--detector-choice-policy") + .help( + "Experimental detector pivot policy: order, fewest_incident_errors, " + "largest_budget, or largest_budget_per_incident.") + .default_value(std::string("order")) + .store_into(args.detector_choice_policy); + program.add_argument("--error-order-policy") + .help("Experimental sibling ordering policy: 
static or reduced_cost.") + .default_value(std::string("static")) + .store_into(args.error_order_policy); + program.add_argument("--root-det-order-count") + .help("Experimental: at shallow depths, union candidates from the first N detector orders.") + .default_value(size_t(1)) + .store_into(args.root_det_order_count); + program.add_argument("--root-det-order-depth") + .help("Experimental: use root-det-order-count while node depth is less than this value.") + .default_value(size_t(0)) + .store_into(args.root_det_order_depth); + program.add_argument("--exact-child-refine-count") + .help( + "Experimental exact mode: immediately LP-refine the first N generated children per " + "expanded node.") + .default_value(size_t(0)) + .store_into(args.exact_child_refine_count); + + program.add_argument("--num-det-orders") + .help("Number of ways to orient the manifold when reordering the detectors") + .metavar("N") + .default_value(size_t(1)) + .store_into(args.num_det_orders); + program.add_argument("--det-order-bfs") + .help("Use BFS-based detector ordering") + .flag() + .store_into(args.det_order_bfs); + program.add_argument("--det-order-index") + .help("Randomly choose increasing or decreasing detector index order") + .flag() + .store_into(args.det_order_index); + program.add_argument("--det-order-coordinate") + .help("Random geometric detector orientation ordering") + .flag() + .store_into(args.det_order_coordinate); + program.add_argument("--det-order-seed") + .help("Seed used when initializing the random detector traversal orderings.") + .default_value(static_cast(518278944)) + .store_into(args.det_order_seed); + + program.add_argument("--sample-num-shots") + .help("Sample the requested number of shots from the Stim circuit.") + .store_into(args.sample_num_shots); + program.add_argument("--max-errors") + .help("Stop after at least this many errors have been observed.") + .store_into(args.max_errors); + program.add_argument("--sample-seed") + .help("Seed used when 
initializing the random number generator for sampling shots") + .default_value(static_cast(std::random_device()())) + .store_into(args.sample_seed); + + program.add_argument("--shot-range-begin") + .default_value(size_t(0)) + .store_into(args.shot_range_begin); + program.add_argument("--shot-range-end").default_value(size_t(0)).store_into(args.shot_range_end); + + program.add_argument("--in").default_value(std::string("")).store_into(args.in_fname); + std::string in_formats; + bool first = true; + for (const auto& [key, value] : stim::format_name_to_enum_map()) { + if (!first) in_formats += "/"; + first = false; + in_formats += key; + } + program.add_argument("--in-format", "--in_format") + .default_value(std::string("")) + .store_into(args.in_format); + program.add_argument("--in-includes-appended-observables", "--in_includes_appended_observables") + .default_value(false) + .store_into(args.append_observables) + .flag(); + program.add_argument("--obs_in", "--obs-in") + .default_value(std::string("")) + .store_into(args.obs_in_fname); + program.add_argument("--obs-in-format", "--obs_in_format") + .default_value(std::string("")) + .store_into(args.obs_in_format); + program.add_argument("--out").default_value(std::string("")).store_into(args.out_fname); + program.add_argument("--out-format").default_value(std::string("")).store_into(args.out_format); + program.add_argument("--dem-out").default_value(std::string("")).store_into(args.dem_out_fname); + program.add_argument("--stats-out") + .default_value(std::string("")) + .store_into(args.stats_out_fname); + + program.add_argument("--threads") + .default_value(size_t( + std::thread::hardware_concurrency() == 0 ? 
1 : std::thread::hardware_concurrency())) + .store_into(args.num_threads); + program.add_argument("--beam").default_value(INF_DET_BEAM).store_into(args.det_beam); + program.add_argument("--det-penalty").default_value(0.0).store_into(args.det_penalty); + program.add_argument("--beam-climbing").flag().store_into(args.beam_climbing); + program.add_argument("--no-revisit-dets").flag().store_into(args.no_revisit_dets); + program.add_argument("--pqlimit") + .default_value(std::numeric_limits::max()) + .store_into(args.pqlimit); + program.add_argument("--verbose").flag().store_into(args.verbose); + program.add_argument("--print-stats").flag().store_into(args.print_stats); + + try { + program.parse_args(argc, argv); + } catch (const std::exception& err) { + std::cerr << err.what() << std::endl; + std::cerr << program; + return EXIT_FAILURE; + } + args.validate(); + + TesseractFTLConfig config; + std::vector shots; + std::unique_ptr writer; + args.extract(config, shots, writer); + + std::vector obs_predicted(shots.size()); + std::vector cost_predicted(shots.size()); + std::vector decoding_time_seconds(shots.size()); + std::vector> low_confidence(shots.size()); + const stim::DetectorErrorModel original_dem = config.dem.flattened(); + std::vector> decoders(args.num_threads); + std::vector> error_use_per_thread( + args.num_threads, std::vector(original_dem.count_errors())); + std::vector decoder_stats_per_thread(args.num_threads); + + bool has_obs = args.has_observables(); + size_t num_errors = 0; + size_t num_low_confidence = 0; + double total_time_seconds = 0; + size_t num_observables = config.dem.count_observables(); + + size_t shot = parallel_for_shots_in_order( + shots.size(), args.num_threads, + [&](size_t thread_index, size_t shot_index) { + if (!decoders[thread_index]) { + decoders[thread_index] = std::make_unique(config); + } + auto& decoder = *decoders[thread_index]; + auto& error_use = error_use_per_thread[thread_index]; + auto start_time = 
std::chrono::high_resolution_clock::now(); + decoder.decode_to_errors(shots[shot_index].hits); + auto stop_time = std::chrono::high_resolution_clock::now(); + decoding_time_seconds[shot_index] = + std::chrono::duration_cast(stop_time - start_time).count() / + 1e6; + obs_predicted[shot_index] = + vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer)); + low_confidence[shot_index] = decoder.low_confidence_flag; + cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer); + decoder_stats_per_thread[thread_index].accumulate(decoder.stats); + if (!has_obs || shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { + for (size_t ei : decoder.predicted_errors_buffer) ++error_use[ei]; + } + }, + [&](size_t shot_index) { + if (writer) { + writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables); + writer->write_end(); + } + if (low_confidence[shot_index]) { + ++num_low_confidence; + } else if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) { + ++num_errors; + } + total_time_seconds += decoding_time_seconds[shot_index]; + if (args.print_stats) { + std::cout << "num_shots = " << (shot_index + 1) + << " num_low_confidence = " << num_low_confidence + << " num_errors = " << num_errors + << " total_time_seconds = " << total_time_seconds << std::endl; + std::cout << "cost = " << cost_predicted[shot_index] << std::endl; + std::cout.flush(); + } + return num_errors < args.max_errors; + }); + + std::vector error_use_totals(original_dem.count_errors()); + for (const auto& error_use : error_use_per_thread) { + for (size_t ei = 0; ei < error_use_totals.size(); ++ei) error_use_totals[ei] += error_use[ei]; + } + TesseractFTLStats decoder_stats_total; + for (const auto& s : decoder_stats_per_thread) decoder_stats_total.accumulate(s); + + if (!args.dem_out_fname.empty()) { + size_t num_usage_dem_shots = shot; + if (has_obs) num_usage_dem_shots -= num_errors; + stim::DetectorErrorModel 
est_dem = + common::dem_from_counts(original_dem, error_use_totals, num_usage_dem_shots); + std::ofstream out(args.dem_out_fname, std::ofstream::out); + if (!out.is_open()) throw std::invalid_argument("Failed to open " + args.dem_out_fname); + out << est_dem << '\n'; + } + + bool print_final_stats = true; + if (!args.stats_out_fname.empty()) { + nlohmann::json stats_json = { + {"circuit_path", args.circuit_path}, + {"dem_path", args.dem_path}, + {"max_errors", args.max_errors}, + {"sample_seed", args.sample_seed}, + {"det_beam", args.det_beam}, + {"det_penalty", args.det_penalty}, + {"beam_climbing", args.beam_climbing}, + {"no_revisit_dets", args.no_revisit_dets}, + {"pqlimit", args.pqlimit}, + {"num_det_orders", args.num_det_orders}, + {"det_order_seed", args.det_order_seed}, + {"subset_detcost_size", args.subset_detcost_size}, + {"ignore_blocked_errors_in_heuristic", args.ignore_blocked_errors_in_heuristic}, + {"num_min_dets_to_consider", args.num_min_dets_to_consider}, + {"detector_choice_policy", args.detector_choice_policy}, + {"error_order_policy", args.error_order_policy}, + {"root_det_order_count", args.root_det_order_count}, + {"root_det_order_depth", args.root_det_order_depth}, + {"exact_child_refine_count", args.exact_child_refine_count}, + {"total_time_seconds", total_time_seconds}, + {"num_errors", num_errors}, + {"num_low_confidence", num_low_confidence}, + {"num_shots", shot}, + {"num_threads", args.num_threads}, + {"sample_num_shots", args.sample_num_shots}, + {"ftl_num_pq_pushed", decoder_stats_total.num_pq_pushed}, + {"ftl_num_nodes_popped", decoder_stats_total.num_nodes_popped}, + {"ftl_max_queue_size", decoder_stats_total.max_queue_size}, + {"ftl_heuristic_calls", decoder_stats_total.heuristic_calls}, + {"ftl_plain_heuristic_calls", decoder_stats_total.plain_heuristic_calls}, + {"ftl_projection_heuristic_calls", decoder_stats_total.projection_heuristic_calls}, + {"ftl_exact_refinement_calls", decoder_stats_total.exact_refinement_calls}, + 
{"ftl_lp_calls", decoder_stats_total.lp_calls}, + {"ftl_lp_reinserts", decoder_stats_total.lp_reinserts}, + {"ftl_projected_nodes_generated", decoder_stats_total.projected_nodes_generated}, + {"ftl_projected_nodes_refined", decoder_stats_total.projected_nodes_refined}, + {"ftl_total_lp_refinement_gain", decoder_stats_total.total_lp_refinement_gain}, + {"ftl_max_lp_refinement_gain", decoder_stats_total.max_lp_refinement_gain}, + {"ftl_lp_total_seconds", decoder_stats_total.lp_total_seconds}, + {"ftl_chain_replay_total_seconds", decoder_stats_total.chain_replay_total_seconds}, + {"ftl_component_build_total_seconds", decoder_stats_total.component_build_total_seconds}, + {"ftl_component_candidate_total_seconds", + decoder_stats_total.component_candidate_total_seconds}, + {"ftl_component_union_total_seconds", decoder_stats_total.component_union_total_seconds}, + {"ftl_component_dedup_total_seconds", decoder_stats_total.component_dedup_total_seconds}, + {"ftl_component_finalize_total_seconds", + decoder_stats_total.component_finalize_total_seconds}, + {"ftl_simplex_total_seconds", decoder_stats_total.simplex_total_seconds}, + {"ftl_projection_total_seconds", decoder_stats_total.projection_total_seconds}, + {"ftl_component_build_calls", decoder_stats_total.component_build_calls}, + {"ftl_simplex_calls", decoder_stats_total.simplex_calls}, + {"ftl_projection_calls", decoder_stats_total.projection_calls}, + {"ftl_detector_choice_calls", decoder_stats_total.detector_choice_calls}, + {"ftl_error_ordering_calls", decoder_stats_total.error_ordering_calls}, + {"ftl_total_active_detectors_popped", decoder_stats_total.total_active_detectors_popped}, + {"ftl_total_root_order_candidates", decoder_stats_total.total_root_order_candidates}, + {"ftl_total_min_detector_candidates", decoder_stats_total.total_min_detector_candidates}, + {"ftl_total_min_detectors_selected", decoder_stats_total.total_min_detectors_selected}, + {"ftl_total_min_detector_available_errors", + 
decoder_stats_total.total_min_detector_available_errors}, + {"ftl_total_min_detector_blocked_errors", + decoder_stats_total.total_min_detector_blocked_errors}, + {"ftl_total_child_candidates_considered", + decoder_stats_total.total_child_candidates_considered}, + {"ftl_total_children_generated", decoder_stats_total.total_children_generated}, + {"ftl_total_children_beam_pruned", decoder_stats_total.total_children_beam_pruned}, + {"ftl_total_children_infeasible", decoder_stats_total.total_children_infeasible}, + {"ftl_total_selected_min_detector_budget", + decoder_stats_total.total_selected_min_detector_budget}, + {"ftl_exact_child_pre_refinements", decoder_stats_total.exact_child_pre_refinements}, + }; + + if (args.stats_out_fname == "-") { + std::cout << stats_json << std::endl; + print_final_stats = false; + } else { + std::ofstream out(args.stats_out_fname, std::ofstream::out); + out << stats_json << std::endl; + } + } + + if (print_final_stats) { + std::cout << "num_shots = " << shot; + std::cout << " num_low_confidence = " << num_low_confidence; + if (has_obs) std::cout << " num_errors = " << num_errors; + std::cout << " total_time_seconds = " << total_time_seconds; + if (args.subset_detcost_size > 0) { + std::cout << " lp_calls = " << decoder_stats_total.lp_calls; + std::cout << " lp_reinserts = " << decoder_stats_total.lp_reinserts; + std::cout << " projected_nodes_generated = " << decoder_stats_total.projected_nodes_generated; + std::cout << " projected_nodes_refined = " << decoder_stats_total.projected_nodes_refined; + std::cout << " child_candidates = " << decoder_stats_total.total_child_candidates_considered; + std::cout << " children_generated = " << decoder_stats_total.total_children_generated; + } + std::cout << std::endl; + } + return 0; +} diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index ab5ed9c..ff7b3d0 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -483,6 +483,8 @@ int main(int argc, char* argv[]) { std::vector 
obs_predicted(shots.size()); std::vector cost_predicted(shots.size()); std::vector decoding_time_seconds(shots.size()); + std::vector num_pq_pushed_per_shot(shots.size()); + std::vector num_pq_popped_per_shot(shots.size()); std::vector> low_confidence(shots.size()); const stim::DetectorErrorModel original_dem = config.dem.flattened(); std::vector> decoders(args.num_threads); @@ -511,6 +513,8 @@ int main(int argc, char* argv[]) { vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer)); low_confidence[shot_index] = decoder.low_confidence_flag; cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer); + num_pq_pushed_per_shot[shot_index] = decoder.num_pq_pushed; + num_pq_popped_per_shot[shot_index] = decoder.num_pq_popped; if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { for (size_t ei : decoder.predicted_errors_buffer) { ++error_use[ei]; @@ -532,6 +536,8 @@ int main(int argc, char* argv[]) { std::cout << "num_shots = " << (shot_index + 1) << " num_low_confidence = " << num_low_confidence << " num_errors = " << num_errors + << " num_pq_pushed = " << num_pq_pushed_per_shot[shot_index] + << " num_pq_popped = " << num_pq_popped_per_shot[shot_index] << " total_time_seconds = " << total_time_seconds << std::endl; std::cout << "cost = " << cost_predicted[shot_index] << std::endl; std::cout.flush(); diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc new file mode 100644 index 0000000..89dbe35 --- /dev/null +++ b/src/tesseract_trellis.cc @@ -0,0 +1,1073 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tesseract_trellis.h" + +#include +#include +#include +#include +#include +#include +#if defined(__BMI2__) && \ + (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) +#include +#endif +#include +#include +#include +#include +#include + +#include "utils.h" + +struct TesseractTrellisWideKernelBase { + virtual ~TesseractTrellisWideKernelBase() = default; + virtual void decode_shot(TesseractTrellisDecoder* decoder, + const std::vector& detections) const = 0; +}; + +namespace { + +constexpr size_t kMaxCompiledWideStateWords = 4; + +#if defined(__GNUC__) || defined(__clang__) +#define TESSERACT_ALWAYS_INLINE inline __attribute__((always_inline)) +#else +#define TESSERACT_ALWAYS_INLINE inline +#endif + +struct Fault { + size_t error_index; + double likelihood_cost; + double log_q; + double log_p; + uint64_t obs_mask; + std::vector detectors; +}; + +template +using FixedWideStateWords = std::array; + +template +struct FixedWideStateEntry { + FixedWideStateWords state_words{}; + double mass0 = 0.0; + double mass1 = 0.0; + double penalty = 0.0; + double score = -INF; +}; + +template +struct FixedWidePairBucket { + FixedWideStateWords key{}; + double mass0[2]{}; + double mass1[2]{}; + double penalty[2]{}; + uint8_t used_mask = 0; + bool occupied = false; +}; + +template +struct CompiledWideLayerTemplate { + double q = 0.0; + double p = 0.0; + bool toggles_observable = false; + bool has_retiring_terms = false; + size_t surviving_term_count = 0; + std::array surviving_masks{}; + std::array projection_dst_words{}; + 
std::array projection_dst_offsets{}; + std::array projected_fault_mask_words{}; + std::vector fault_target_word_indices; + std::vector fault_target_bit_masks; + std::vector fault_word_indices; + std::vector fault_bit_masks; + std::vector fault_was_active_before; + std::vector current_costs; + std::vector next_costs; +}; + +struct BranchPenaltyUpdate { + bool absent_valid = true; + bool present_valid = true; + double absent_penalty = 0.0; + double present_penalty = 0.0; +}; + +struct FinalizeKeptStateStatsOnExit { + TesseractTrellisDecoder* decoder; + + ~FinalizeKeptStateStatsOnExit(); +}; + +std::vector parse_faults(const std::vector& errors, size_t num_observables) { + std::vector faults; + faults.reserve(errors.size()); + for (size_t error_index = 0; error_index < errors.size(); ++error_index) { + const auto& error = errors[error_index]; + const double p = error.get_probability(); + if (p <= 0) { + continue; + } + Fault fault; + fault.error_index = error_index; + fault.likelihood_cost = error.likelihood_cost; + fault.log_q = std::log1p(-p); + fault.log_p = std::log(p); + fault.obs_mask = 0; + for (int obs : error.symptom.observables) { + if (obs >= 64) { + throw std::invalid_argument("tesseract_trellis currently supports at most 64 observables"); + } + if (size_t(obs) >= num_observables) { + throw std::invalid_argument("Observable index out of range in DEM"); + } + fault.obs_mask ^= uint64_t{1} << obs; + } + fault.detectors = error.symptom.detectors; + faults.push_back(std::move(fault)); + } + return faults; +} + +void build_wide_layer_templates(const std::vector& faults, size_t num_detectors, + std::vector* layers, + size_t* max_frontier_width_seen) { + std::vector last_seen(num_detectors, std::numeric_limits::max()); + for (size_t i = 0; i < faults.size(); ++i) { + for (int d : faults[i].detectors) { + last_seen[d] = i; + } + } + + std::vector active_detectors; + active_detectors.reserve(num_detectors); + std::vector global_to_local(num_detectors, -1); + 
layers->clear(); + layers->reserve(faults.size()); + *max_frontier_width_seen = 0; + + for (size_t i = 0; i < faults.size(); ++i) { + const size_t previous_width = active_detectors.size(); + for (int d : faults[i].detectors) { + if (global_to_local[d] == -1) { + global_to_local[d] = active_detectors.size(); + active_detectors.push_back(d); + } + } + + *max_frontier_width_seen = std::max(*max_frontier_width_seen, active_detectors.size()); + TesseractTrellisWideLayerTemplate layer{ + .q = std::exp(faults[i].log_q), + .p = std::exp(faults[i].log_p), + .obs_mask = faults[i].obs_mask, + .previous_width = previous_width, + .surviving_local_indices = {}, + .current_active_detectors = active_detectors, + .projected_fault_mask_words = {}, + .next_frontier_costs = {}, + .detcost_transition = {}, + }; + + for (size_t local = 0; local < active_detectors.size(); ++local) { + const int d = active_detectors[local]; + if (last_seen[d] != i) { + layer.surviving_local_indices.push_back((uint32_t)local); + } + } + + std::vector next_active; + next_active.reserve(layer.surviving_local_indices.size()); + std::fill(global_to_local.begin(), global_to_local.end(), -1); + for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { + int d = active_detectors[layer.surviving_local_indices[next_local]]; + global_to_local[d] = next_local; + next_active.push_back(d); + } + active_detectors = std::move(next_active); + layers->push_back(std::move(layer)); + } +} + +template +void build_future_detcost_transitions(const std::vector& faults, size_t num_detectors, + std::vector* layers, + std::vector* initial_future_detcost) { + std::vector current_row(num_detectors, INF); + for (size_t fault_index = faults.size(); fault_index-- > 0;) { + auto& layer = (*layers)[fault_index]; + const auto& fault = faults[fault_index]; + + layer.next_frontier_costs.resize(layer.surviving_local_indices.size(), INF); + for (size_t next_local = 0; next_local < 
layer.surviving_local_indices.size(); ++next_local) { + size_t current_local = (size_t)layer.surviving_local_indices[next_local]; + int global_detector = layer.current_active_detectors[current_local]; + layer.next_frontier_costs[next_local] = current_row[(size_t)global_detector]; + } + + std::vector current_to_next(layer.current_active_detectors.size(), -1); + for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { + current_to_next[(size_t)layer.surviving_local_indices[next_local]] = (int32_t)next_local; + } + + auto& transition = layer.detcost_transition; + transition.fault_local_indices.clear(); + transition.next_local_indices.clear(); + transition.current_costs.clear(); + transition.next_costs.clear(); + transition.fault_local_indices.reserve(fault.detectors.size()); + transition.next_local_indices.reserve(fault.detectors.size()); + transition.current_costs.reserve(fault.detectors.size()); + transition.next_costs.reserve(fault.detectors.size()); + + if (!fault.detectors.empty()) { + double ecost = fault.likelihood_cost / fault.detectors.size(); + for (int detector : fault.detectors) { + auto it = std::find(layer.current_active_detectors.begin(), + layer.current_active_detectors.end(), detector); + if (it == layer.current_active_detectors.end()) { + throw std::runtime_error("Missing detector in active frontier while preparing detcost."); + } + uint32_t local = (uint32_t)std::distance(layer.current_active_detectors.begin(), it); + double next_cost = current_row[(size_t)detector]; + double current_cost = std::min(ecost, next_cost); + transition.fault_local_indices.push_back(local); + transition.next_local_indices.push_back(current_to_next[local]); + transition.current_costs.push_back(current_cost); + transition.next_costs.push_back(next_cost); + current_row[(size_t)detector] = current_cost; + } + } + } + + if (initial_future_detcost != nullptr) { + *initial_future_detcost = std::move(current_row); + } +} + +size_t 
num_state_words(size_t num_bits) {
  // Number of 64-bit words needed to store num_bits bits (ceiling division by 64).
  return (num_bits + 63) >> 6;
}

// Index of the 64-bit word that holds the given detector's bit.
TESSERACT_ALWAYS_INLINE size_t detector_word_index(size_t detector) {
  return detector >> 6;
}

// Single-bit mask for the detector's position within its word.
TESSERACT_ALWAYS_INLINE uint64_t detector_word_mask(size_t detector) {
  return uint64_t{1} << (detector & 63);
}

// Gathers the bits of `value` selected by `mask` into the low bits of the
// result (parallel bit extract). Uses the BMI2 PEXT instruction on x86 when
// available; otherwise falls back to a portable loop over the set bits of
// `mask`, lowest first.
TESSERACT_ALWAYS_INLINE uint64_t compact_bits_u64(uint64_t value, uint64_t mask) {
#if defined(__BMI2__) && \
    (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86))
  return _pext_u64(value, mask);
#else
  uint64_t out = 0;
  uint64_t out_bit = 1;
  while (mask) {
    uint64_t keep = mask & -mask;  // lowest set bit still in the mask
    if (value & keep) {
      out |= out_bit;
    }
    mask ^= keep;
    out_bit <<= 1;
  }
  return out;
#endif
}

// Sums the per-detector costs of every active detector whose bit is set in
// actual_detector_words (i.e. detectors that fired in this shot). Returns INF
// as soon as a fired detector has INF cost, so an unsatisfiable detector
// poisons the whole initial penalty.
double compute_initial_penalty_for_active_detectors(
    const std::vector& active_detector_word_indices,
    const std::vector& active_detector_bit_masks,
    const std::vector& active_detector_costs,
    const std::vector& actual_detector_words) {
  double total = 0.0;
  for (size_t k = 0; k < active_detector_costs.size(); ++k) {
    if ((actual_detector_words[active_detector_word_indices[k]] & active_detector_bit_masks[k]) ==
        0) {
      continue;  // detector did not fire; contributes nothing
    }
    double best = active_detector_costs[k];
    if (best == INF) {
      return INF;
    }
    total += best;
  }
  return total;
}

// Beam-ranking score for a state. MassOnly ranks by raw probability mass;
// otherwise the score is log(mass) minus the detcost penalty, with -INF as
// the sentinel for unusable states (zero mass or infinite penalty).
double score_mass_and_penalty(double mass, double penalty,
                              TesseractTrellisRankingMode ranking_mode) {
  if (ranking_mode == TesseractTrellisRankingMode::MassOnly) {
    return mass;
  }
  if (penalty == INF || mass == 0.0) {
    return -INF;
  }
  return std::log(mass) - penalty;
}

// Total probability mass of a beam entry: mass with observable clear plus
// mass with observable set.
// NOTE(review): bare `template` lines here look like angle-bracketed
// parameter lists (e.g. `template <size_t Words>`) were stripped during
// extraction — confirm against the original patch.
template
TESSERACT_ALWAYS_INLINE double total_entry_mass(const FixedWideStateEntry& entry) {
  return entry.mass0 + entry.mass1;
}

// SplitMix64-style finalizer used as the per-word mixer for state hashing.
TESSERACT_ALWAYS_INLINE uint64_t mix_splitmix64(uint64_t value) {
  value += 0x9e3779b97f4a7c15ULL;
  value = (value ^ (value >> 30)) * 0xbf58476d1ce4e5b9ULL;
  value = (value ^ (value >> 27)) * 0x94d049bb133111ebULL;
  return value ^ (value >> 31);
}

void 
reset_kept_state_stats(TesseractTrellisDecoder* decoder) {
  // Reset the summary statistics over per-layer kept-state counts. When
  // tracking is enabled, also (re)zero the histogram scratch: one bucket for
  // each possible kept count in [0, beam_width].
  decoder->kept_state_sample_count = 0;
  decoder->kept_state_min = 0;
  decoder->kept_state_median = 0;
  decoder->kept_state_mean = 0;
  decoder->kept_state_max = 0;
  if (!decoder->config.track_kept_state_stats) {
    return;
  }

  const size_t histogram_size = decoder->config.beam_width + 1;
  if (decoder->kept_state_histogram_scratch.size() != histogram_size) {
    decoder->kept_state_histogram_scratch.assign(histogram_size, 0);
  } else {
    // Same size as last time: zero it in place instead of reallocating.
    std::fill(decoder->kept_state_histogram_scratch.begin(),
              decoder->kept_state_histogram_scratch.end(), 0);
  }
}

// Records one sample of how many beam states were kept after a layer's
// truncation: updates the running min/max, accumulates the sum for the mean,
// and bumps the histogram bucket used later for the median. No-op unless
// track_kept_state_stats is enabled.
void record_kept_state_count(TesseractTrellisDecoder* decoder, size_t kept_states) {
  if (!decoder->config.track_kept_state_stats) {
    return;
  }

  // Clamp so the histogram index stays within [0, beam_width].
  kept_states = std::min(kept_states, decoder->config.beam_width);
  if (decoder->kept_state_sample_count == 0) {
    // First sample initializes min and max directly.
    decoder->kept_state_min = kept_states;
    decoder->kept_state_max = kept_states;
  } else {
    decoder->kept_state_min = std::min(decoder->kept_state_min, kept_states);
    decoder->kept_state_max = std::max(decoder->kept_state_max, kept_states);
  }
  ++decoder->kept_state_sample_count;
  // kept_state_mean holds the raw running sum here; finalize_kept_state_stats
  // divides by the sample count.
  decoder->kept_state_mean += kept_states;
  ++decoder->kept_state_histogram_scratch[kept_states];
}

// Converts the accumulated sum into the mean and walks the histogram to find
// the median as the average of the lower and upper middle samples (so even
// sample counts get the midpoint of the two central values).
void finalize_kept_state_stats(TesseractTrellisDecoder* decoder) {
  if (!decoder->config.track_kept_state_stats || decoder->kept_state_sample_count == 0) {
    return;
  }

  decoder->kept_state_mean /= decoder->kept_state_sample_count;
  // 0-based ranks of the two middle samples; equal when the count is odd.
  const size_t lower_target = (decoder->kept_state_sample_count - 1) >> 1;
  const size_t upper_target = decoder->kept_state_sample_count >> 1;
  size_t seen = 0;
  size_t lower = 0;
  size_t upper = 0;
  bool lower_found = false;
  for (size_t kept_states = 0; kept_states < decoder->kept_state_histogram_scratch.size();
       ++kept_states) {
    seen += decoder->kept_state_histogram_scratch[kept_states];
    if (!lower_found && seen > lower_target) {
      lower = kept_states;
      
lower_found = true;
    }
    if (seen > upper_target) {
      upper = kept_states;
      break;  // both middle samples located
    }
  }
  decoder->kept_state_median = 0.5 * (lower + upper);
}

// Scope guard: publishes the kept-state statistics when decoding exits,
// including on early returns.
FinalizeKeptStateStatsOnExit::~FinalizeKeptStateStatsOnExit() {
  finalize_kept_state_stats(decoder);
}

// For each layer, builds the bit mask (over the post-projection frontier) of
// detectors toggled by that layer's fault, using only terms whose detector
// survives into the next frontier (next_local >= 0).
void prepare_projected_fault_masks(std::vector* layers) {
  for (auto& layer : *layers) {
    layer.projected_fault_mask_words.assign(num_state_words(layer.surviving_local_indices.size()),
                                            0);
    for (int32_t next_local : layer.detcost_transition.next_local_indices) {
      if (next_local >= 0) {
        size_t local = (size_t)next_local;
        layer.projected_fault_mask_words[local >> 6] ^= uint64_t{1} << (local & 63);
      }
    }
  }
}

// Lexicographic "less than" over state words, most significant word first.
// NOTE(review): bare `template` lines in this span look like stripped
// `template <size_t Words>` parameter lists — confirm against the original.
template
bool fixed_wide_state_less(const FixedWideStateWords& a, const FixedWideStateWords& b) {
  for (size_t k = Words; k-- > 0;) {
    if (a[k] != b[k]) {
      return a[k] < b[k];
    }
  }
  return false;
}

// True iff every word of the state is zero (no residual detectors).
template
bool fixed_wide_state_zero(const FixedWideStateWords& state_words) {
  for (size_t k = 0; k < Words; ++k) {
    if (state_words[k] != 0) {
      return false;
    }
  }
  return true;
}

// XORs a mask into the state, word by word.
template
void xor_compiled_wide_state(FixedWideStateWords* state_words,
                             const std::array& mask_words) {
  for (size_t k = 0; k < Words; ++k) {
    (*state_words)[k] ^= mask_words[k];
  }
}

// Hash of a wide state: each word is mixed with the SplitMix64-style
// finalizer (salted by word index) and folded in via XOR plus a rotation.
template
TESSERACT_ALWAYS_INLINE uint64_t hash_fixed_wide_state(const FixedWideStateWords& state_words) {
  uint64_t hash = 0x123456789abcdef0ULL;
  for (size_t k = 0; k < Words; ++k) {
    hash ^= mix_splitmix64(state_words[k] + 0x9e3779b97f4a7c15ULL * (k + 1));
    hash = std::rotl(hash, 21);
  }
  return hash;
}

// Grows the open-addressed bucket table so it can hold roughly twice the
// parent count (minimum 16), rounded up to a power of two so that
// `hash & (size - 1)` works as the probe mask. The table never shrinks.
// NOTE(review): `std::max(16, num_parents * 2)` mixes int and size_t; the
// original likely read `std::max<size_t>(...)` with the template argument
// lost in extraction — confirm against the original patch.
template
void ensure_pair_bucket_capacity(std::vector>* buckets,
                                 size_t num_parents) {
  const size_t required = std::bit_ceil(std::max(16, num_parents * 2));
  if (buckets->size() < required) {
    buckets->resize(required);
  }
}

// Marks only the buckets touched in the previous layer as free again and
// clears the used-index list, avoiding a full sweep of the table.
template
void clear_pair_buckets(std::vector>* buckets,
                        std::vector* used_bucket_indices) {
  for (size_t index : *used_bucket_indices) {
    
(*buckets)[index].occupied = false; + (*buckets)[index].used_mask = 0; + } + used_bucket_indices->clear(); +} + +template +TESSERACT_ALWAYS_INLINE size_t find_or_insert_pair_bucket( + std::vector>* buckets, std::vector* used_bucket_indices, + const FixedWideStateWords& key) { + const size_t mask = buckets->size() - 1; + size_t index = hash_fixed_wide_state(key) & mask; + while ((*buckets)[index].occupied) { + if ((*buckets)[index].key == key) { + return index; + } + index = (index + 1) & mask; + } + + auto& bucket = (*buckets)[index]; + bucket.occupied = true; + bucket.key = key; + bucket.used_mask = 0; + used_bucket_indices->push_back(index); + return index; +} + +template +TESSERACT_ALWAYS_INLINE void accumulate_pair_bucket_slot(FixedWidePairBucket* bucket, + uint8_t slot, double mass0, double mass1, + double penalty) { + const uint8_t bit = (uint8_t)(1u << slot); + if ((bucket->used_mask & bit) == 0) { + bucket->mass0[slot] = mass0; + bucket->mass1[slot] = mass1; + bucket->penalty[slot] = penalty; + bucket->used_mask |= bit; + } else { + bucket->mass0[slot] += mass0; + bucket->mass1[slot] += mass1; + } +} + +template +FixedWideStateWords project_compiled_wide_state( + const FixedWideStateWords& state_words, const CompiledWideLayerTemplate& layer) { + FixedWideStateWords out{}; + for (size_t src_word = 0; src_word < Words; ++src_word) { + const uint64_t mask = layer.surviving_masks[src_word]; + if (mask == 0) { + continue; + } + const uint64_t packed = compact_bits_u64(state_words[src_word], mask); + const size_t dst_word = layer.projection_dst_words[src_word]; + const uint8_t shift = layer.projection_dst_offsets[src_word]; + out[dst_word] |= packed << shift; + if constexpr (Words > 1) { + if (shift != 0 && dst_word + 1 < Words) { + out[dst_word + 1] |= packed >> (64 - shift); + } + } + } + return out; +} + +template +TESSERACT_ALWAYS_INLINE BranchPenaltyUpdate compute_compiled_wide_branch_update( + const FixedWideStateWords& base_state_words, double 
current_penalty, + const std::vector& actual_detector_words, const CompiledWideLayerTemplate& layer) { + BranchPenaltyUpdate update; + update.absent_penalty = ComputePenalties ? current_penalty : 0.0; + update.present_penalty = ComputePenalties ? current_penalty : 0.0; + + for (size_t k = 0; k < layer.surviving_term_count; ++k) { + const bool state_bit = + layer.fault_was_active_before[k] && + ((base_state_words[layer.fault_word_indices[k]] & layer.fault_bit_masks[k]) != 0); + const bool target_bit = + (actual_detector_words[layer.fault_target_word_indices[k]] & + layer.fault_target_bit_masks[k]) != 0; + const bool mismatch = state_bit ^ target_bit; + + if constexpr (ComputePenalties) { + const double prev_contrib = + (layer.fault_was_active_before[k] && mismatch) ? layer.current_costs[k] : 0.0; + const double next_contrib = mismatch ? layer.next_costs[k] : 0.0; + update.absent_penalty += next_contrib - prev_contrib; + update.present_penalty += (layer.next_costs[k] - next_contrib) - prev_contrib; + } + } + + if constexpr (CheckRetiringTerms) { + for (size_t k = layer.surviving_term_count; k < layer.fault_target_word_indices.size(); ++k) { + const bool state_bit = + layer.fault_was_active_before[k] && + ((base_state_words[layer.fault_word_indices[k]] & layer.fault_bit_masks[k]) != 0); + const bool target_bit = + (actual_detector_words[layer.fault_target_word_indices[k]] & + layer.fault_target_bit_masks[k]) != 0; + const bool mismatch = state_bit ^ target_bit; + + if (mismatch) { + update.absent_valid = false; + } else { + update.present_valid = false; + } + + if constexpr (ComputePenalties) { + const double prev_contrib = + (layer.fault_was_active_before[k] && mismatch) ? 
layer.current_costs[k] : 0.0; + update.absent_penalty -= prev_contrib; + update.present_penalty -= prev_contrib; + } + } + } + + return update; +} + +template +void expand_compiled_layer_into_pair_buckets( + const std::vector>& beam_entries, + std::vector>* pair_buckets, std::vector* used_bucket_indices, + const std::vector& actual_detector_words, const CompiledWideLayerTemplate& layer, + TesseractTrellisDecoder* decoder) { + for (const auto& item : beam_entries) { + ++decoder->num_states_expanded; + BranchPenaltyUpdate update = compute_compiled_wide_branch_update( + item.state_words, item.penalty, actual_detector_words, layer); + + if (!update.absent_valid && !update.present_valid) { + continue; + } + + FixedWideStateWords projected_state = project_compiled_wide_state(item.state_words, layer); + FixedWideStateWords projected_toggled = projected_state; + xor_compiled_wide_state(&projected_toggled, layer.projected_fault_mask_words); + const bool projected_is_key = !fixed_wide_state_less(projected_toggled, projected_state); + const auto& bucket_key = projected_is_key ? projected_state : projected_toggled; + const uint8_t absent_slot = projected_is_key ? 0 : 1; + const uint8_t present_slot = projected_toggled == bucket_key ? 
0 : 1; + const size_t bucket_index = + find_or_insert_pair_bucket(pair_buckets, used_bucket_indices, bucket_key); + auto& bucket = (*pair_buckets)[bucket_index]; + const bool keep_absent = update.absent_valid && layer.q != 0.0; + const bool keep_present = update.present_valid && layer.p != 0.0; + + if (keep_absent) { + accumulate_pair_bucket_slot(&bucket, absent_slot, item.mass0 * layer.q, item.mass1 * layer.q, + update.absent_penalty); + } + if (keep_present) { + if (layer.toggles_observable) { + accumulate_pair_bucket_slot(&bucket, present_slot, item.mass1 * layer.p, item.mass0 * layer.p, + update.present_penalty); + } else { + accumulate_pair_bucket_slot(&bucket, present_slot, item.mass0 * layer.p, item.mass1 * layer.p, + update.present_penalty); + } + } + } +} + +template +void normalize_compiled_items(std::vector>* items) { + double total = 0.0; + for (const auto& item : *items) { + total += total_entry_mass(item); + } + if (total == 0.0) { + return; + } + for (auto& item : *items) { + item.mass0 /= total; + item.mass1 /= total; + } +} + +template +void merge_equal_compiled_keys_inplace(std::vector>* items) { + if (items->empty()) { + return; + } + std::sort(items->begin(), items->end(), + [](const FixedWideStateEntry& a, const FixedWideStateEntry& b) { + return fixed_wide_state_less(a.state_words, b.state_words); + }); + + size_t out = 0; + for (size_t i = 1; i < items->size(); ++i) { + if ((*items)[i].state_words == (*items)[out].state_words) { + (*items)[out].mass0 += (*items)[i].mass0; + (*items)[out].mass1 += (*items)[i].mass1; + } else { + ++out; + if (out != i) { + (*items)[out] = std::move((*items)[i]); + } + } + } + items->resize(out + 1); +} + +template +bool compiled_state_score_greater(const FixedWideStateEntry& a, + const FixedWideStateEntry& b) { + if (a.score != b.score) { + return a.score > b.score; + } + return fixed_wide_state_less(a.state_words, b.state_words); +} + +template +size_t keep_top_compiled_states(std::vector>* entries, + size_t 
beam_width, double beam_eps, + TesseractTrellisRankingMode ranking_mode) { + if (entries->empty()) { + return 0; + } + + double total_mass = 0.0; + for (auto& entry : *entries) { + const double mass = total_entry_mass(entry); + entry.score = score_mass_and_penalty(mass, entry.penalty, ranking_mode); + if (beam_eps > 0.0) { + total_mass += mass; + } + } + + if (entries->size() > beam_width) { + std::nth_element(entries->begin(), entries->begin() + beam_width, entries->end(), + [](const FixedWideStateEntry& a, const FixedWideStateEntry& b) { + return compiled_state_score_greater(a, b); + }); + entries->resize(beam_width); + } else if (beam_eps <= 0.0) { + return entries->size(); + } + + if (beam_eps <= 0.0 || total_mass <= 0.0) { + return entries->size(); + } + + std::sort(entries->begin(), entries->end(), + [](const FixedWideStateEntry& a, const FixedWideStateEntry& b) { + return compiled_state_score_greater(a, b); + }); + const double retained_target_mass = total_mass * (1.0 - beam_eps); + double retained_mass = 0.0; + size_t keep_count = 0; + while (keep_count < entries->size()) { + retained_mass += total_entry_mass((*entries)[keep_count]); + ++keep_count; + if (retained_mass >= retained_target_mass) { + break; + } + } + entries->resize(keep_count); + return keep_count; +} + +template +std::vector> compile_wide_layers( + const std::vector& layers) { + std::vector> compiled_layers; + compiled_layers.reserve(layers.size()); + for (const auto& layer : layers) { + if (num_state_words(layer.current_active_detectors.size()) > Words || + layer.projected_fault_mask_words.size() > Words) { + throw std::invalid_argument("Compiled wide kernel word count is smaller than the frontier."); + } + + CompiledWideLayerTemplate compiled; + compiled.q = layer.q; + compiled.p = layer.p; + if (layer.obs_mask > 1) { + throw std::invalid_argument("tesseract_trellis currently supports at most 1 observable"); + } + compiled.toggles_observable = layer.obs_mask != 0; + + std::array 
surviving_masks{}; + for (uint32_t current_local : layer.surviving_local_indices) { + surviving_masks[current_local >> 6] |= uint64_t{1} << (current_local & 63); + } + size_t next_offset = 0; + for (size_t src_word = 0; src_word < Words; ++src_word) { + compiled.surviving_masks[src_word] = surviving_masks[src_word]; + compiled.projection_dst_words[src_word] = static_cast(next_offset >> 6); + compiled.projection_dst_offsets[src_word] = static_cast(next_offset & 63); + next_offset += std::popcount(surviving_masks[src_word]); + } + + for (size_t k = 0; k < layer.projected_fault_mask_words.size(); ++k) { + compiled.projected_fault_mask_words[k] = layer.projected_fault_mask_words[k]; + } + + const auto& transition = layer.detcost_transition; + const size_t term_count = transition.fault_local_indices.size(); + compiled.fault_target_word_indices.reserve(term_count); + compiled.fault_target_bit_masks.reserve(term_count); + compiled.fault_word_indices.reserve(term_count); + compiled.fault_bit_masks.reserve(term_count); + compiled.fault_was_active_before.reserve(term_count); + compiled.current_costs.reserve(term_count); + compiled.next_costs.reserve(term_count); + auto append_term = [&](size_t idx) { + const uint32_t local = transition.fault_local_indices[idx]; + const uint32_t detector = (uint32_t)layer.current_active_detectors[local]; + compiled.fault_target_word_indices.push_back((uint32_t)detector_word_index(detector)); + compiled.fault_target_bit_masks.push_back(detector_word_mask(detector)); + compiled.fault_word_indices.push_back(static_cast(local >> 6)); + compiled.fault_bit_masks.push_back(uint64_t{1} << (local & 63)); + compiled.fault_was_active_before.push_back(local < layer.previous_width); + compiled.current_costs.push_back(transition.current_costs[idx]); + compiled.next_costs.push_back(transition.next_costs[idx]); + }; + for (size_t idx = 0; idx < term_count; ++idx) { + if (transition.next_local_indices[idx] >= 0) { + append_term(idx); + } + } + 
compiled.surviving_term_count = compiled.fault_target_word_indices.size(); + compiled.has_retiring_terms = compiled.surviving_term_count != term_count; + for (size_t idx = 0; idx < term_count; ++idx) { + if (transition.next_local_indices[idx] < 0) { + append_term(idx); + } + } + + compiled_layers.push_back(std::move(compiled)); + } + return compiled_layers; +} + +template +struct CompiledWideKernel final : TesseractTrellisWideKernelBase { + explicit CompiledWideKernel(std::vector> layers_, + std::vector initial_detector_word_indices_, + std::vector initial_detector_bit_masks_, + std::vector initial_detector_costs_, + size_t max_frontier_width_) + : layers(std::move(layers_)), + initial_detector_word_indices(std::move(initial_detector_word_indices_)), + initial_detector_bit_masks(std::move(initial_detector_bit_masks_)), + initial_detector_costs(std::move(initial_detector_costs_)), + max_frontier_width(max_frontier_width_) {} + + void decode_shot(TesseractTrellisDecoder* decoder, + const std::vector& detections) const override { + auto& actual_detector_words = decoder->actual_detector_words_scratch; + std::fill(actual_detector_words.begin(), actual_detector_words.end(), 0); + for (uint64_t d : detections) { + if (d >= decoder->num_detectors) { + decoder->low_confidence_flag = true; + return; + } + const size_t word = detector_word_index((size_t)d); + const uint64_t mask = detector_word_mask((size_t)d); + if ((decoder->all_possible_detector_words[word] & mask) == 0) { + decoder->low_confidence_flag = true; + return; + } + actual_detector_words[word] ^= mask; + } + + decoder->max_frontier_width_seen = max_frontier_width; + + double initial_penalty = 0.0; + if (decoder->config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked && + !layers.empty()) { + initial_penalty = compute_initial_penalty_for_active_detectors(initial_detector_word_indices, + initial_detector_bit_masks, + initial_detector_costs, + actual_detector_words); + } + + std::vector> 
beam_entries; + std::vector> next_entries; + std::vector> pair_buckets; + std::vector used_bucket_indices; + beam_entries.reserve(decoder->config.beam_width * 2 + 2); + next_entries.reserve(decoder->config.beam_width * 4 + 4); + beam_entries.push_back({{}, 1.0, 0.0, initial_penalty}); + decoder->max_beam_size_seen = 1; + + const bool compute_penalties = + decoder->config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked; + for (size_t layer_index = 0; layer_index < layers.size(); ++layer_index) { + const auto& layer = layers[layer_index]; + + ensure_pair_bucket_capacity(&pair_buckets, beam_entries.size()); + clear_pair_buckets(&pair_buckets, &used_bucket_indices); + + auto t0 = std::chrono::high_resolution_clock::now(); + + if (decoder->config.verbose) { + std::cout << "expanding layer " << layer_index << " / " << (layers.size() - 1) + << std::endl; + std::cout << "states to expand = " << beam_entries.size() << std::endl; + } + if (compute_penalties) { + if (layer.has_retiring_terms) { + expand_compiled_layer_into_pair_buckets( + beam_entries, &pair_buckets, &used_bucket_indices, actual_detector_words, layer, + decoder); + } else { + expand_compiled_layer_into_pair_buckets( + beam_entries, &pair_buckets, &used_bucket_indices, actual_detector_words, layer, + decoder); + } + } else if (layer.has_retiring_terms) { + expand_compiled_layer_into_pair_buckets( + beam_entries, &pair_buckets, &used_bucket_indices, actual_detector_words, layer, + decoder); + } else { + expand_compiled_layer_into_pair_buckets( + beam_entries, &pair_buckets, &used_bucket_indices, actual_detector_words, layer, + decoder); + } + auto t1 = std::chrono::high_resolution_clock::now(); + decoder->time_expand_seconds += + std::chrono::duration_cast(t1 - t0).count() / 1e6; + + auto t2a = std::chrono::high_resolution_clock::now(); + next_entries.clear(); + next_entries.reserve(used_bucket_indices.size() * 2); + for (size_t index : used_bucket_indices) { + auto& bucket = 
pair_buckets[index]; + if ((bucket.used_mask & 1u) != 0) { + next_entries.push_back({bucket.key, bucket.mass0[0], bucket.mass1[0], bucket.penalty[0]}); + } + if ((bucket.used_mask & 2u) != 0) { + auto other_state = bucket.key; + xor_compiled_wide_state(&other_state, layer.projected_fault_mask_words); + next_entries.push_back( + {std::move(other_state), bucket.mass0[1], bucket.mass1[1], bucket.penalty[1]}); + } + } + beam_entries.swap(next_entries); + auto t2 = std::chrono::high_resolution_clock::now(); + decoder->time_collapse_seconds += + std::chrono::duration_cast(t2 - t2a).count() / 1e6; + + const size_t kept_states = keep_top_compiled_states( + &beam_entries, decoder->config.beam_width, decoder->config.beam_eps, + decoder->config.ranking_mode); + normalize_compiled_items(&beam_entries); + record_kept_state_count(decoder, beam_entries.empty() ? 0 : kept_states); + if (beam_entries.empty()) { + decoder->low_confidence_flag = true; + return; + } + decoder->num_states_merged += kept_states; + decoder->max_beam_size_seen = std::max(decoder->max_beam_size_seen, kept_states); + auto t3 = std::chrono::high_resolution_clock::now(); + decoder->time_truncate_seconds += + std::chrono::duration_cast(t3 - t2).count() / 1e6; + } + + auto tr0 = std::chrono::high_resolution_clock::now(); + for (const auto& item : beam_entries) { + if (!fixed_wide_state_zero(item.state_words)) { + continue; + } + decoder->total_mass_obs0 += item.mass0; + decoder->total_mass_obs1 += item.mass1; + } + if (decoder->total_mass_obs0 == 0.0 && decoder->total_mass_obs1 == 0.0) { + decoder->low_confidence_flag = true; + return; + } + decoder->predicted_obs_mask = decoder->total_mass_obs1 > decoder->total_mass_obs0 ? 
1 : 0; + auto tr1 = std::chrono::high_resolution_clock::now(); + decoder->time_reconstruct_seconds += + std::chrono::duration_cast(tr1 - tr0).count() / 1e6; + } + + std::vector> layers; + std::vector initial_detector_word_indices; + std::vector initial_detector_bit_masks; + std::vector initial_detector_costs; + size_t max_frontier_width; +}; + +std::unique_ptr build_compiled_wide_kernel( + const std::vector& layers, size_t max_frontier_width, + const std::vector& initial_future_detcost) { + const size_t required_words = std::max(1, num_state_words(max_frontier_width)); + if (required_words > kMaxCompiledWideStateWords) { + throw std::invalid_argument("Wide trellis frontier requires " + std::to_string(required_words) + + " words, but only " + + std::to_string(kMaxCompiledWideStateWords) + + " compiled words are enabled."); + } + + std::vector initial_detector_word_indices; + std::vector initial_detector_bit_masks; + std::vector initial_detector_costs; + if (!layers.empty()) { + const auto& initial_active_detectors = layers.front().current_active_detectors; + initial_detector_word_indices.reserve(initial_active_detectors.size()); + initial_detector_bit_masks.reserve(initial_active_detectors.size()); + initial_detector_costs.reserve(initial_active_detectors.size()); + for (int detector : initial_active_detectors) { + initial_detector_word_indices.push_back((uint32_t)detector_word_index((size_t)detector)); + initial_detector_bit_masks.push_back(detector_word_mask((size_t)detector)); + initial_detector_costs.push_back(initial_future_detcost[(size_t)detector]); + } + } + switch (required_words) { + case 1: + return std::make_unique>( + compile_wide_layers<1>(layers), initial_detector_word_indices, initial_detector_bit_masks, + initial_detector_costs, max_frontier_width); + case 2: + return std::make_unique>( + compile_wide_layers<2>(layers), initial_detector_word_indices, initial_detector_bit_masks, + initial_detector_costs, max_frontier_width); + case 3: + return 
std::make_unique>( + compile_wide_layers<3>(layers), initial_detector_word_indices, initial_detector_bit_masks, + initial_detector_costs, max_frontier_width); + case 4: + return std::make_unique>( + compile_wide_layers<4>(layers), initial_detector_word_indices, initial_detector_bit_masks, + initial_detector_costs, max_frontier_width); + default: + throw std::invalid_argument("Unsupported compiled wide trellis word count."); + } +} + +} // namespace + +TesseractTrellisDecoder::~TesseractTrellisDecoder() = default; + +TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_) + : config(std::move(config_)) { + std::vector dem_error_map(config.dem.flattened().count_errors()); + std::iota(dem_error_map.begin(), dem_error_map.end(), 0); + dem_error_to_error = std::move(dem_error_map); + error_to_dem_error = common::invert_error_map(dem_error_to_error, config.dem.count_errors()); + errors = get_errors_from_dem(config.dem.flattened()); + num_detectors = config.dem.count_detectors(); + num_observables = config.dem.count_observables(); + if (num_observables > 1) { + throw std::invalid_argument("tesseract_trellis currently supports at most 1 observable"); + } + + all_possible_detector_words.assign(num_state_words(num_detectors), 0); + actual_detector_words_scratch.assign(all_possible_detector_words.size(), 0); + for (const auto& error : errors) { + for (int d : error.symptom.detectors) { + all_possible_detector_words[detector_word_index((size_t)d)] |= + detector_word_mask((size_t)d); + } + } + + auto faults = parse_faults(errors, num_observables); + + size_t wide_frontier_width = 0; + build_wide_layer_templates(faults, num_detectors, &wide_layer_templates, &wide_frontier_width); + std::vector initial_future_detcost; + build_future_detcost_transitions(faults, num_detectors, &wide_layer_templates, + &initial_future_detcost); + prepare_projected_fault_masks(&wide_layer_templates); + wide_kernel = + build_compiled_wide_kernel(wide_layer_templates, 
wide_frontier_width, initial_future_detcost); +} + +__attribute__((hot)) void TesseractTrellisDecoder::decode_shot( + const std::vector& detections) { + low_confidence_flag = false; + num_states_expanded = 0; + num_states_merged = 0; + max_beam_size_seen = 0; + max_frontier_width_seen = 0; + reset_kept_state_stats(this); + time_expand_seconds = 0; + time_collapse_seconds = 0; + time_truncate_seconds = 0; + time_reconstruct_seconds = 0; + predicted_obs_mask = 0; + total_mass_obs0 = 0; + total_mass_obs1 = 0; + FinalizeKeptStateStatsOnExit kept_state_stats_guard{this}; + wide_kernel->decode_shot(this, detections); + + if (config.verbose) { + std::cout << "trellis beam_width=" << config.beam_width + << " frontier_width=" << max_frontier_width_seen + << " states_expanded=" << num_states_expanded + << " states_merged=" << num_states_merged << " max_beam=" << max_beam_size_seen + << std::endl; + } +} + +std::vector TesseractTrellisDecoder::decode(const std::vector& detections) { + decode_shot(detections); + return predicted_obs_mask ? std::vector{0} : std::vector{}; +} + +void TesseractTrellisDecoder::decode_shots(std::vector& shots, + std::vector>& obs_predicted) { + obs_predicted.resize(shots.size()); + for (size_t i = 0; i < shots.size(); ++i) { + obs_predicted[i] = decode(shots[i].hits); + } +} diff --git a/src/tesseract_trellis.h b/src/tesseract_trellis.h new file mode 100644 index 0000000..d3b6ee0 --- /dev/null +++ b/src/tesseract_trellis.h @@ -0,0 +1,100 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TESSERACT_TRELLIS_DECODER_H +#define TESSERACT_TRELLIS_DECODER_H + +#include +#include +#include + +#include "common.h" +#include "stim.h" + +struct TesseractTrellisWideKernelBase; + +enum class TesseractTrellisRankingMode { + MassOnly, + FutureDetcostRanked, +}; + +struct TesseractTrellisDetcostTransition { + std::vector fault_local_indices; + std::vector next_local_indices; + std::vector current_costs; + std::vector next_costs; +}; + +struct TesseractTrellisWideLayerTemplate { + double q = 0; + double p = 0; + uint64_t obs_mask = 0; + size_t previous_width = 0; + std::vector surviving_local_indices; + std::vector current_active_detectors; + std::vector projected_fault_mask_words; + std::vector next_frontier_costs; + TesseractTrellisDetcostTransition detcost_transition; +}; + +struct TesseractTrellisConfig { + stim::DetectorErrorModel dem; + size_t beam_width = 1024; + double beam_eps = 0.0; + bool verbose = false; + bool track_kept_state_stats = false; + TesseractTrellisRankingMode ranking_mode = TesseractTrellisRankingMode::MassOnly; +}; + +struct TesseractTrellisDecoder { + explicit TesseractTrellisDecoder(TesseractTrellisConfig config); + ~TesseractTrellisDecoder(); + + void decode_shot(const std::vector& detections); + std::vector decode(const std::vector& detections); + void decode_shots(std::vector& shots, + std::vector>& obs_predicted); + + TesseractTrellisConfig config; + bool low_confidence_flag = false; + size_t num_states_expanded = 0; + size_t num_states_merged = 0; + size_t max_beam_size_seen = 0; + size_t max_frontier_width_seen = 0; + size_t kept_state_sample_count = 0; + size_t kept_state_min = 0; + double kept_state_median = 0; + double kept_state_mean = 0; + size_t kept_state_max = 0; + double time_expand_seconds = 0; + double time_collapse_seconds = 0; + double time_truncate_seconds = 0; + double time_reconstruct_seconds = 0; + 
uint64_t predicted_obs_mask = 0; + double total_mass_obs0 = 0; + double total_mass_obs1 = 0; + + std::vector dem_error_to_error; + std::vector error_to_dem_error; + std::vector errors; + size_t num_observables = 0; + size_t num_detectors = 0; + std::vector all_possible_detector_words; + std::vector actual_detector_words_scratch; + std::vector wide_layer_templates; + std::unique_ptr wide_kernel; + std::vector kept_state_histogram_scratch; +}; + +#endif // TESSERACT_TRELLIS_DECODER_H diff --git a/src/tesseract_trellis_main.cc b/src/tesseract_trellis_main.cc new file mode 100644 index 0000000..28e2386 --- /dev/null +++ b/src/tesseract_trellis_main.cc @@ -0,0 +1,429 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "stim.h" +#include "tesseract_trellis.h" +#include "utils.h" + +namespace { + +TesseractTrellisRankingMode parse_ranking_mode(const std::string& value) { + if (value == "mass") return TesseractTrellisRankingMode::MassOnly; + if (value == "future-detcost") return TesseractTrellisRankingMode::FutureDetcostRanked; + throw std::invalid_argument("Unknown trellis ranking mode: " + value); +} + +} // namespace + +struct Args { + std::string circuit_path; + std::string dem_path; + + size_t sample_num_shots = 0; + size_t max_errors = SIZE_MAX; + uint64_t sample_seed; + + size_t shot_range_begin = 0; + size_t shot_range_end = 0; + + std::string in_fname = ""; + std::string in_format = ""; + std::string obs_in_fname = ""; + std::string obs_in_format = ""; + bool append_observables = false; + std::string out_fname = ""; + std::string out_format = ""; + + std::string dem_out_fname = ""; + std::string stats_out_fname = ""; + + size_t num_threads = 1; + size_t beam_width = 1024; + double beam_eps = 0.0; + std::string ranking_mode = "mass"; + + bool verbose = false; + bool print_stats = false; + + bool has_observables() { + return append_observables || !obs_in_fname.empty() || (sample_num_shots > 0); + } + + void validate() { + if (circuit_path.empty() && dem_path.empty()) { + throw std::invalid_argument("Must provide at least one of --circuit or --dem"); + } + int num_data_sources = int(sample_num_shots > 0) + int(!in_fname.empty()); + if (num_data_sources != 1) { + throw std::invalid_argument("Requires exactly 1 source of shots."); + } + if (!in_fname.empty() && in_format.empty()) { + throw std::invalid_argument("If --in is provided, must also specify --in-format."); + } + if (!out_fname.empty() && out_format.empty()) { + throw std::invalid_argument("If --out is provided, must also specify --out-format."); + } + if (!in_format.empty() && 
!stim::format_name_to_enum_map().contains(in_format)) { + throw std::invalid_argument("Invalid format: " + in_format); + } + if (!obs_in_format.empty() && !stim::format_name_to_enum_map().contains(obs_in_format)) { + throw std::invalid_argument("Invalid format: " + obs_in_format); + } + if (!out_format.empty() && !stim::format_name_to_enum_map().contains(out_format)) { + throw std::invalid_argument("Invalid format: " + out_format); + } + if (!obs_in_fname.empty() && in_fname.empty()) { + throw std::invalid_argument( + "Cannot load observable flips without a corresponding detection event data file."); + } + if (num_threads == 0) { + throw std::invalid_argument("--threads must be at least 1."); + } + if (num_threads > 1000) { + throw std::invalid_argument("There is a maximum limit of 1000 threads."); + } + if ((shot_range_begin || shot_range_end) && shot_range_end < shot_range_begin) { + throw std::invalid_argument("Provided shot range must have end >= begin."); + } + if (sample_num_shots > 0 && circuit_path.empty()) { + throw std::invalid_argument("Cannot sample shots without a circuit."); + } + if (beam_width == 0) { + throw std::invalid_argument("--beam must be at least 1."); + } + if (!std::isfinite(beam_eps) || beam_eps < 0.0 || beam_eps >= 1.0) { + throw std::invalid_argument("--beam-eps must satisfy 0 <= beam-eps < 1."); + } + parse_ranking_mode(ranking_mode); + } + + void extract(TesseractTrellisConfig& config, std::vector& shots, + std::unique_ptr& writer) { + stim::Circuit circuit; + if (!circuit_path.empty()) { + FILE* file = fopen(circuit_path.c_str(), "r"); + if (!file) { + throw std::invalid_argument("Could not open the file: " + circuit_path); + } + circuit = stim::Circuit::from_file(file); + fclose(file); + } + + if (!dem_path.empty()) { + FILE* file = fopen(dem_path.c_str(), "r"); + if (!file) { + throw std::invalid_argument("Could not open the file: " + dem_path); + } + config.dem = stim::DetectorErrorModel::from_file(file); + fclose(file); + } else 
{ + assert(!circuit_path.empty()); + config.dem = stim::ErrorAnalyzer::circuit_to_detector_error_model( + circuit, /*decompose_errors=*/false, /*fold_loops=*/true, + /*allow_gauge_detectors=*/true, + /*approximate_disjoint_errors_threshold=*/1, + /*ignore_decomposition_failures=*/false, + /*block_decomposition_from_introducing_remnant_edges=*/false); + } + + config.beam_width = beam_width; + config.beam_eps = beam_eps; + config.verbose = verbose; + config.track_kept_state_stats = print_stats; + config.ranking_mode = parse_ranking_mode(ranking_mode); + + if (sample_num_shots > 0) { + assert(!circuit_path.empty()); + std::mt19937_64 rng(sample_seed); + size_t num_detectors = circuit.count_detectors(); + const auto [dets, obs] = + stim::sample_batch_detection_events<64>(circuit, sample_num_shots, rng); + stim::simd_bit_table<64> obs_T = obs.transposed(); + shots.resize(sample_num_shots); + for (size_t k = 0; k < sample_num_shots; k++) { + shots[k].obs_mask = obs_T[k]; + for (size_t d = 0; d < num_detectors; d++) { + if (dets[d][k]) { + shots[k].hits.push_back(d); + } + } + } + } + + if (!in_fname.empty()) { + FILE* shots_file = fopen(in_fname.c_str(), "r"); + if (!shots_file) { + throw std::invalid_argument("Could not open the file: " + in_fname); + } + stim::FileFormatData shots_in_format = stim::format_name_to_enum_map().at(in_format); + auto reader = stim::MeasureRecordReader::make( + shots_file, shots_in_format.id, 0, config.dem.count_detectors(), + append_observables * config.dem.count_observables()); + stim::SparseShot sparse_shot; + sparse_shot.clear(); + while (reader->start_and_read_entire_record(sparse_shot)) { + shots.push_back(sparse_shot); + sparse_shot.clear(); + } + fclose(shots_file); + } + + if (!obs_in_fname.empty()) { + FILE* obs_file = fopen(obs_in_fname.c_str(), "r"); + if (!obs_file) { + throw std::invalid_argument("Could not open the file: " + obs_in_fname); + } + stim::FileFormatData obs_format = 
stim::format_name_to_enum_map().at(obs_in_format); + auto obs_reader = stim::MeasureRecordReader::make( + obs_file, obs_format.id, 0, 0, config.dem.count_observables()); + stim::SparseShot sparse_shot; + sparse_shot.clear(); + size_t num_obs_shots = 0; + while (obs_reader->start_and_read_entire_record(sparse_shot)) { + if (num_obs_shots >= shots.size()) { + throw std::invalid_argument("Shot data ended before obs data."); + } + shots[num_obs_shots].obs_mask = sparse_shot.obs_mask; + sparse_shot.clear(); + ++num_obs_shots; + } + if (num_obs_shots != shots.size()) { + throw std::invalid_argument("Obs data ended before shot data ended."); + } + fclose(obs_file); + } + + if (shot_range_begin || shot_range_end) { + if (shot_range_end > shots.size()) { + throw std::invalid_argument("Shot range end is past end of shots array."); + } + std::vector shots_in_range(shots.begin() + shot_range_begin, + shots.begin() + shot_range_end); + std::swap(shots_in_range, shots); + } + + if (!out_fname.empty()) { + stim::FileFormatData predictions_out_format = stim::format_name_to_enum_map().at(out_format); + FILE* predictions_file = stdout; + if (out_fname != "-") { + predictions_file = fopen(out_fname.c_str(), "w"); + } + writer = stim::MeasureRecordWriter::make(predictions_file, predictions_out_format.id); + writer->begin_result_type('L'); + } + } +}; + +int main(int argc, char* argv[]) { + std::cout.precision(16); + argparse::ArgumentParser program("tesseract_trellis"); + Args args; + program.add_argument("--circuit").help("Stim circuit file path").store_into(args.circuit_path); + program.add_argument("--dem").help("Stim dem file path").store_into(args.dem_path); + program.add_argument("--sample-num-shots").store_into(args.sample_num_shots); + program.add_argument("--max-errors").store_into(args.max_errors); + program.add_argument("--sample-seed") + .default_value(static_cast(std::random_device()())) + .store_into(args.sample_seed); + program.add_argument("--shot-range-begin") + 
.default_value(size_t(0)) + .store_into(args.shot_range_begin); + program.add_argument("--shot-range-end").default_value(size_t(0)).store_into(args.shot_range_end); + program.add_argument("--in").default_value(std::string("")).store_into(args.in_fname); + program.add_argument("--in-format", "--in_format") + .default_value(std::string("")) + .store_into(args.in_format); + program.add_argument("--in-includes-appended-observables", "--in_includes_appended_observables") + .default_value(false) + .store_into(args.append_observables) + .flag(); + program.add_argument("--obs_in", "--obs-in") + .default_value(std::string("")) + .store_into(args.obs_in_fname); + program.add_argument("--obs-in-format", "--obs_in_format") + .default_value(std::string("")) + .store_into(args.obs_in_format); + program.add_argument("--out").default_value(std::string("")).store_into(args.out_fname); + program.add_argument("--out-format").default_value(std::string("")).store_into(args.out_format); + program.add_argument("--dem-out").default_value(std::string("")).store_into(args.dem_out_fname); + program.add_argument("--stats-out") + .default_value(std::string("")) + .store_into(args.stats_out_fname); + program.add_argument("--threads") + .default_value(size_t( + std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency())) + .store_into(args.num_threads); + program.add_argument("--beam").default_value(size_t(1024)).store_into(args.beam_width); + program.add_argument("--beam-eps") + .help( + "Keep at most --beam merged states and also drop the suffix once the kept prefix has " + "accumulated at least (1 - beam-eps) of the total merged-state mass. 
Use 0 to disable " + "the mass-threshold cutoff.") + .default_value(0.0) + .store_into(args.beam_eps); + program.add_argument("--ranking-mode") + .help("Trellis ranking mode: mass or future-detcost") + .default_value(std::string("mass")) + .store_into(args.ranking_mode); + program.add_argument("--verbose").flag().store_into(args.verbose); + program.add_argument("--print-stats").flag().store_into(args.print_stats); + + try { + program.parse_args(argc, argv); + } catch (const std::exception& err) { + std::cerr << err.what() << std::endl; + std::cerr << program; + return EXIT_FAILURE; + } + + args.validate(); + TesseractTrellisConfig config; + std::vector shots; + std::unique_ptr writer; + args.extract(config, shots, writer); + + std::vector obs_predicted(shots.size()); + std::vector mass0_predicted(shots.size()); + std::vector mass1_predicted(shots.size()); + std::vector decoding_time_seconds(shots.size()); + std::vector num_states_expanded_per_shot(shots.size()); + std::vector num_states_merged_per_shot(shots.size()); + std::vector max_beam_size_per_shot(shots.size()); + std::vector max_frontier_width_per_shot(shots.size()); + std::vector kept_state_min_per_shot(shots.size()); + std::vector kept_state_median_per_shot(shots.size()); + std::vector kept_state_mean_per_shot(shots.size()); + std::vector kept_state_max_per_shot(shots.size()); + std::vector time_expand_per_shot(shots.size()); + std::vector time_collapse_per_shot(shots.size()); + std::vector time_truncate_per_shot(shots.size()); + std::vector time_reconstruct_per_shot(shots.size()); + std::vector> low_confidence(shots.size()); + const stim::DetectorErrorModel original_dem = config.dem.flattened(); + std::vector> decoders(args.num_threads); + + bool has_obs = args.has_observables(); + size_t num_errors = 0; + size_t num_low_confidence = 0; + double total_time_seconds = 0; + size_t num_observables = config.dem.count_observables(); + + size_t shot = parallel_for_shots_in_order( + shots.size(), 
args.num_threads, + [&](size_t thread_index, size_t shot_index) { + if (!decoders[thread_index]) { + decoders[thread_index] = std::make_unique(config); + } + auto& decoder = *decoders[thread_index]; + auto start_time = std::chrono::high_resolution_clock::now(); + decoder.decode_shot(shots[shot_index].hits); + auto stop_time = std::chrono::high_resolution_clock::now(); + decoding_time_seconds[shot_index] = + std::chrono::duration_cast(stop_time - start_time).count() / + 1e6; + obs_predicted[shot_index] = decoder.predicted_obs_mask; + low_confidence[shot_index] = decoder.low_confidence_flag; + mass0_predicted[shot_index] = decoder.total_mass_obs0; + mass1_predicted[shot_index] = decoder.total_mass_obs1; + num_states_expanded_per_shot[shot_index] = decoder.num_states_expanded; + num_states_merged_per_shot[shot_index] = decoder.num_states_merged; + max_beam_size_per_shot[shot_index] = decoder.max_beam_size_seen; + max_frontier_width_per_shot[shot_index] = decoder.max_frontier_width_seen; + kept_state_min_per_shot[shot_index] = decoder.kept_state_min; + kept_state_median_per_shot[shot_index] = decoder.kept_state_median; + kept_state_mean_per_shot[shot_index] = decoder.kept_state_mean; + kept_state_max_per_shot[shot_index] = decoder.kept_state_max; + time_expand_per_shot[shot_index] = decoder.time_expand_seconds; + time_collapse_per_shot[shot_index] = decoder.time_collapse_seconds; + time_truncate_per_shot[shot_index] = decoder.time_truncate_seconds; + time_reconstruct_per_shot[shot_index] = decoder.time_reconstruct_seconds; + }, + [&](size_t shot_index) { + if (writer) { + writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables); + writer->write_end(); + } + if (low_confidence[shot_index]) { + ++num_low_confidence; + } else if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) { + ++num_errors; + } + total_time_seconds += decoding_time_seconds[shot_index]; + if (args.print_stats) { + std::cout << "num_shots = " << (shot_index + 1) + << " 
num_low_confidence = " << num_low_confidence + << " num_errors = " << num_errors + << " states_expanded = " << num_states_expanded_per_shot[shot_index] + << " states_merged = " << num_states_merged_per_shot[shot_index] + << " max_beam = " << max_beam_size_per_shot[shot_index] + << " frontier_width = " << max_frontier_width_per_shot[shot_index] + << " total_time_seconds = " << total_time_seconds << '\n'; + std::cout << "kept_states" << " min=" << kept_state_min_per_shot[shot_index] + << " median=" << kept_state_median_per_shot[shot_index] + << " mean=" << kept_state_mean_per_shot[shot_index] + << " max=" << kept_state_max_per_shot[shot_index] << '\n'; + std::cout << "branch_masses" << " obs0=" << mass0_predicted[shot_index] + << " obs1=" << mass1_predicted[shot_index] << '\n'; + std::cout << "phase_times_seconds" << " expand=" << time_expand_per_shot[shot_index] + << " collapse=" << time_collapse_per_shot[shot_index] + << " truncate=" << time_truncate_per_shot[shot_index] + << " reconstruct=" << time_reconstruct_per_shot[shot_index] << '\n'; + } + return num_errors < args.max_errors; + }); + + if (!args.dem_out_fname.empty()) { + throw std::invalid_argument( + "--dem-out is not supported by tesseract_trellis without path reconstruction."); + } + + bool print_final_stats = true; + if (!args.stats_out_fname.empty()) { + nlohmann::json stats_json = {{"circuit_path", args.circuit_path}, + {"dem_path", args.dem_path}, + {"beam_width", args.beam_width}, + {"beam_eps", args.beam_eps}, + {"sample_seed", args.sample_seed}, + {"sample_num_shots", args.sample_num_shots}, + {"num_threads", args.num_threads}, + {"num_errors", num_errors}, + {"num_low_confidence", num_low_confidence}, + {"num_shots", shot}, + {"total_time_seconds", total_time_seconds}}; + if (args.stats_out_fname == "-") { + std::cout << stats_json << std::endl; + print_final_stats = false; + } else { + std::ofstream out(args.stats_out_fname, std::ofstream::out); + out << stats_json << std::endl; + } + } + + if 
(print_final_stats) { + std::cout << "num_shots = " << shot << " num_low_confidence = " << num_low_confidence; + if (has_obs) { + std::cout << " num_errors = " << num_errors; + } + std::cout << " total_time_seconds = " << total_time_seconds << std::endl; + } +}