forked from ionq-publications/BeamSearchDecoder
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path sinter_beamsearch.py
More file actions
144 lines (128 loc) · 5.77 KB
/
sinter_beamsearch.py
File metadata and controls
144 lines (128 loc) · 5.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# This file is part of BeamSearchDecoder.
# Copyright (c) 2025 IonQ, Inc., all rights reserved
# Licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International License (CC BY-NC-SA 4.0).
# You may obtain a copy of the License at:
# https://creativecommons.org/licenses/by-nc-sa/4.0/
import pathlib
from typing import Dict
import numpy as np
import stim
from beamsearch import BeamSearch
from sinter import CompiledDecoder, Decoder
class SinterCompiledDecoder_BeamSearch(CompiledDecoder):
    """Sinter ``CompiledDecoder`` adapter around an already-built ``BeamSearch``.

    The wrapped decoder is stored as ``self.decoder``. Every call to
    ``decode_shots_bit_packed`` is delegated to its ``decode_batch`` method
    with both the input shots and the output predictions in bit-packed form.
    """

    def __init__(self, decoder: "BeamSearch"):
        # Keep the public attribute name so external code that reads
        # ``.decoder`` keeps working.
        self.decoder = decoder

    def decode_shots_bit_packed(
        self,
        *,
        bit_packed_detection_event_data: "np.ndarray",
    ) -> "np.ndarray":
        """Decode a batch of bit-packed detection events.

        Args:
            bit_packed_detection_event_data: Detection-event data exactly as
                sinter supplies it. It is passed to ``decode_batch`` with
                ``bit_packed_shots=True``.

        Returns:
            Whatever ``BeamSearch.decode_batch`` returns when called with
            ``bit_packed_predictions=True``. Presumably this is the bit-packed
            observable predictions that sinter expects; confirm against
            ``BeamSearch``.
        """
        beam = self.decoder
        return beam.decode_batch(
            bit_packed_predictions=True,
            bit_packed_shots=True,
            shots=bit_packed_detection_event_data,
        )
class SinterDecoder_BeamSearch(Decoder):
    """Sinter ``Decoder`` that decodes stim circuits with ``BeamSearch``.

    Sinter can use this class through either of two entry points. Both build a
    ``BeamSearch`` decoder from the circuit's detector error model, using the
    settings given to ``__init__``:

    * ``compile_decoder_for_dem`` returns a ``SinterCompiledDecoder_BeamSearch``
      that decodes bit-packed detection events in memory.
    * ``decode_via_files`` reads detection events from a b8 file, decodes them,
      and writes the observable predictions to a b8 file.
    """

    def __init__(
        self,
        max_rounds: int = 10,
        beam_width: int = 8,
        num_results: int = 1,
        initial_iters: int = 30,
        iters_per_round: int = 20,
        **bp_kwargs,
    ):
        """Store the configuration used to build ``BeamSearch`` decoders.

        Each argument is stored on the instance. Each time a decoder is built
        for a detector error model, the argument is forwarded unchanged, by
        keyword, to ``BeamSearch``. ``BeamSearch`` defines what the arguments
        mean. The parenthesised notes below are inferred from the names only;
        verify them against ``BeamSearch``.

        Parameters
        ----------
        max_rounds : int, optional
            Forwarded as ``max_rounds`` (presumably the maximum number of
            beam-search rounds). Default 10.
        beam_width : int, optional
            Forwarded as ``beam_width`` (presumably the number of candidates
            kept per round). Default 8.
        num_results : int, optional
            Forwarded as ``num_results``. Default 1.
        initial_iters : int, optional
            Forwarded as ``initial_iters`` (presumably the iteration budget of
            the initial decoding pass). Default 30.
        iters_per_round : int, optional
            Forwarded as ``iters_per_round`` (presumably the iteration budget
            of each subsequent round). Default 20.
        **bp_kwargs
            Any further keyword arguments, forwarded unchanged to
            ``BeamSearch``. The name suggests they configure its
            belief-propagation subroutine.
        """
        self.max_rounds = max_rounds
        self.beam_width = beam_width
        self.num_results = num_results
        self.initial_iters = initial_iters
        self.iters_per_round = iters_per_round
        self.bp_kwargs = bp_kwargs

    def _build_beamsearch(self, dem: stim.DetectorErrorModel) -> "BeamSearch":
        """Construct a ``BeamSearch`` for ``dem`` using this instance's settings.

        Both sinter entry points call this helper, so a new option only has to
        be wired through in one place.
        """
        return BeamSearch(
            model=dem,
            max_rounds=self.max_rounds,
            beam_width=self.beam_width,
            num_results=self.num_results,
            initial_iters=self.initial_iters,
            iters_per_round=self.iters_per_round,
            **self.bp_kwargs,
        )

    def compile_decoder_for_dem(self, *, dem: stim.DetectorErrorModel) -> CompiledDecoder:
        """Build a ``BeamSearch`` for ``dem`` and wrap it for sinter's in-memory API."""
        return SinterCompiledDecoder_BeamSearch(self._build_beamsearch(dem))

    def decode_via_files(
        self,
        *,
        num_shots: int,
        num_dets: int,
        num_obs: int,
        dem_path: pathlib.Path,
        dets_b8_in_path: pathlib.Path,
        obs_predictions_b8_out_path: pathlib.Path,
        tmp_dir: pathlib.Path,
    ) -> None:
        """Performs decoding by reading problems from, and writing solutions to, file paths.

        Args:
            num_shots: The number of times the circuit was sampled. The number of problems
                to be solved.
            num_dets: The number of detectors in the circuit. The number of detection event
                bits in each shot.
            num_obs: The number of observables in the circuit. The number of predicted bits
                in each shot.
            dem_path: The file path where the detector error model should be read from,
                e.g. using `stim.DetectorErrorModel.from_file`. The error mechanisms
                specified by the detector error model should be used to configure the
                decoder.
            dets_b8_in_path: The file path that detection event data should be read from.
                Note that the file may be a named pipe instead of a fixed size object.
                The detection events will be in b8 format (see
                https://github.com/quantumlib/Stim/blob/main/doc/result_formats.md ). The
                number of detection events per shot is available via the `num_dets`
                argument or via the detector error model at `dem_path`.
            obs_predictions_b8_out_path: The file path that decoder predictions must be
                written to. The predictions must be written in b8 format (see
                https://github.com/quantumlib/Stim/blob/main/doc/result_formats.md ). The
                number of observables per shot is available via the `num_obs` argument or
                via the detector error model at `dem_path`.
            tmp_dir: Any temporary files generated by the decoder during its operation MUST
                be put into this directory. The reason for this requirement is because
                sinter is allowed to kill the decoding process without warning, without
                giving it time to clean up any temporary objects. All cleanup should be done
                via sinter deleting this directory after killing the decoder.
        """
        dem = stim.DetectorErrorModel.from_file(dem_path)
        beamsearch = self._build_beamsearch(dem)

        # The detector and observable counts are taken from the DEM itself
        # rather than from num_dets / num_obs; sinter documents these as
        # equivalent sources for the same numbers.
        shots = stim.read_shot_data_file(
            path=dets_b8_in_path,
            format="b8",
            num_detectors=dem.num_detectors,
            bit_packed=False,
        )
        # The shots are unpacked, so decode_batch is called with its default
        # packing options (unlike the compiled decoder's bit-packed path).
        predictions = beamsearch.decode_batch(shots)
        stim.write_shot_data_file(
            data=predictions,
            path=obs_predictions_b8_out_path,
            format="b8",
            num_observables=dem.num_observables,
        )
def sinter_decoders() -> Dict[str, Decoder]:
    """Return the custom decoders provided by this module, keyed by name.

    The mapping holds a single default-configured ``SinterDecoder_BeamSearch``
    under the key ``"beamsearch"``. It is presumably intended for sinter's
    custom-decoder registration; confirm against the caller.
    """
    default_decoder = SinterDecoder_BeamSearch()
    registry = {"beamsearch": default_decoder}
    return registry