Skip to content

Commit dc11cb0

Browse files
authored
Report errors with original DEM indexing (#186)
A big flaw in Tesseract's API is that error indices are always in terms of the internal `errors` vector, which does not correspond directly to the errors in the (flattened) input DEM given by the user. This causes a variety of problems. One such problem is that if we generate a DEM with `--dem-out` (or e.g. similar manual processing steps in Python) it may bear little resemblance to the input DEM. For example, all the targets will be stripped of separators etc. This makes it annoying to use tesseract-calibrated error models for downstream tasks like matching-based decoding.

Here we adopt the principle that the user interface to Tesseract/Simplex decoders should always be in terms of the error indices from the original flattened DEM as provided by the user. This is now true across C++, CLI, and Python APIs.

- Added index-mapping support to DEM preprocessing in common:
  - merge_indistinguishable_errors(..., error_index_map)
  - remove_zero_probability_errors(..., error_index_map)
  - `error_index_map` maps original error index to new preprocessed index
  - `error_index_map` maps removed / redundant errors to std::numeric_limits<size_t>::max()
- Update both decoders (TesseractDecoder, SimplexDecoder) to maintain:
  - dem_error_to_error (original flattened DEM index -> internal index)
  - error_to_dem_error (internal error index -> original flattened DEM index)
- predicted_errors_buffer reports errors back with original flattened DEM error indices.
- cost_from_errors and observables-from-errors methods now:
  - accept original flattened DEM indices
  - throw on unmapped/removed indices (size_t::max())
- Updated Python bindings to use the new helpers.
- Updated pybind common wrappers to pass required map args to common preprocessing functions.
- Updated --dem-out in both CLI binaries:
  - keep original flattened DEM in scope
  - emit updated probabilities by iterating original DEM instruction order
  - preserve original error instruction tags and arbitrary formatting (e.g. `D0 ^ D0 D1`) when writing estimated DEM output.
- Updated tests
- Added AGENTS guidance to run Python Bazel tests
1 parent 4076180 commit dc11cb0

15 files changed

Lines changed: 272 additions & 148 deletions

AGENTS.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ bazel build src:tesseract src:simplex
1919

2020
```bash
2121
bazel test src:all
22+
bazel test //src/py:all
2223
```
2324

2425
## Building with CMake

src/common.cc

Lines changed: 60 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -11,11 +11,11 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14-
1514
#include "common.h"
1615

1716
#include <iomanip>
1817
#include <iostream>
18+
#include <limits>
1919
#include <sstream>
2020
#include <string>
2121
#include <vector>
@@ -117,11 +117,15 @@ double common::merge_weights(double a, double b) {
117117
}
118118

119119
stim::DetectorErrorModel common::merge_indistinguishable_errors(
120-
const stim::DetectorErrorModel& dem) {
120+
const stim::DetectorErrorModel& dem, std::vector<size_t>& error_index_map) {
121121
stim::DetectorErrorModel out_dem;
122122

123-
// Map to track the distinct symptoms
124-
std::unordered_map<Symptom, Error, Symptom::hash> errors_by_symptom;
123+
error_index_map.clear();
124+
125+
// Map to track first-seen distinct symptoms.
126+
std::unordered_map<Symptom, size_t, Symptom::hash> merged_index_by_symptom;
127+
std::vector<Error> merged_errors;
128+
125129
for (const stim::DemInstruction& instruction : dem.flattened().instructions) {
126130
switch (instruction.type) {
127131
case stim::DemInstructionType::DEM_ERROR: {
@@ -131,11 +135,18 @@ stim::DetectorErrorModel common::merge_indistinguishable_errors(
131135
std::cout << "Warning: the circuit has errors that do not flip any detectors \n";
132136
}
133137

134-
if (errors_by_symptom.find(error.symptom) != errors_by_symptom.end()) {
135-
error.likelihood_cost = merge_weights(error.likelihood_cost,
136-
errors_by_symptom[error.symptom].likelihood_cost);
138+
auto it = merged_index_by_symptom.find(error.symptom);
139+
if (it != merged_index_by_symptom.end()) {
140+
size_t merged_error_index = it->second;
141+
merged_errors[merged_error_index].likelihood_cost = merge_weights(
142+
error.likelihood_cost, merged_errors[merged_error_index].likelihood_cost);
143+
error_index_map.push_back(merged_error_index);
144+
} else {
145+
size_t merged_error_index = merged_errors.size();
146+
merged_index_by_symptom[error.symptom] = merged_error_index;
147+
merged_errors.push_back(error);
148+
error_index_map.push_back(merged_error_index);
137149
}
138-
errors_by_symptom[error.symptom] = error;
139150
break;
140151
}
141152
case stim::DemInstructionType::DEM_DETECTOR: {
@@ -150,22 +161,27 @@ stim::DetectorErrorModel common::merge_indistinguishable_errors(
150161
throw std::invalid_argument("Unrecognized instruction type: " + instruction.str());
151162
}
152163
}
153-
for (const auto& it : errors_by_symptom) {
154-
out_dem.append_error_instruction(it.second.get_probability(),
155-
it.second.symptom.as_dem_instruction_targets(),
164+
for (const auto& error : merged_errors) {
165+
out_dem.append_error_instruction(error.get_probability(),
166+
error.symptom.as_dem_instruction_targets(),
156167
/*tag=*/"");
157168
}
158169
return out_dem;
159170
}
160171

161172
stim::DetectorErrorModel common::remove_zero_probability_errors(
162-
const stim::DetectorErrorModel& dem) {
173+
const stim::DetectorErrorModel& dem, std::vector<size_t>& error_index_map) {
163174
stim::DetectorErrorModel out_dem;
175+
error_index_map.clear();
176+
size_t output_error_index = 0;
164177
for (const stim::DemInstruction& instruction : dem.flattened().instructions) {
165178
switch (instruction.type) {
166179
case stim::DemInstructionType::DEM_ERROR:
167180
if (instruction.arg_data[0] > 0) {
168181
out_dem.append_dem_instruction(instruction);
182+
error_index_map.push_back(output_error_index++);
183+
} else {
184+
error_index_map.push_back(std::numeric_limits<size_t>::max());
169185
}
170186
break;
171187
case stim::DemInstructionType::DEM_DETECTOR:
@@ -181,43 +197,46 @@ stim::DetectorErrorModel common::remove_zero_probability_errors(
181197
return out_dem;
182198
}
183199

184-
stim::DetectorErrorModel common::dem_from_counts(stim::DetectorErrorModel& orig_dem,
200+
void common::chain_error_maps(std::vector<size_t>& base_map, const std::vector<size_t>& next_map) {
201+
for (size_t& ei : base_map) {
202+
if (ei != std::numeric_limits<size_t>::max()) {
203+
ei = next_map[ei];
204+
}
205+
}
206+
}
207+
208+
std::vector<size_t> common::invert_error_map(const std::vector<size_t>& error_map,
209+
size_t num_output_errors) {
210+
std::vector<size_t> inverted_map(num_output_errors, std::numeric_limits<size_t>::max());
211+
for (size_t i = 0; i < error_map.size(); ++i) {
212+
size_t mapped_index = error_map[i];
213+
if (mapped_index != std::numeric_limits<size_t>::max() &&
214+
inverted_map[mapped_index] == std::numeric_limits<size_t>::max()) {
215+
inverted_map[mapped_index] = i;
216+
}
217+
}
218+
return inverted_map;
219+
}
220+
221+
stim::DetectorErrorModel common::dem_from_counts(const stim::DetectorErrorModel& orig_dem,
185222
const std::vector<size_t>& error_counts,
186223
size_t num_shots) {
187-
if (orig_dem.count_errors() != error_counts.size()) {
224+
stim::DetectorErrorModel flat_dem = orig_dem.flattened();
225+
if (flat_dem.count_errors() != error_counts.size()) {
188226
throw std::invalid_argument(
189227
"Error hits array must be the same size as the number of errors in the "
190228
"original DEM.");
191229
}
192230

193-
for (const stim::DemInstruction& instruction : orig_dem.flattened().instructions) {
194-
if (instruction.type == stim::DemInstructionType::DEM_ERROR && instruction.arg_data[0] == 0) {
195-
throw std::invalid_argument(
196-
"dem_from_counts requires DEMs without zero-probability errors. Use"
197-
" remove_zero_probability_errors first.");
198-
}
199-
}
200-
201231
stim::DetectorErrorModel out_dem;
202-
size_t ei = 0;
203-
for (const stim::DemInstruction& instruction : orig_dem.flattened().instructions) {
204-
switch (instruction.type) {
205-
case stim::DemInstructionType::DEM_ERROR: {
206-
double est_probability = double(error_counts.at(ei)) / double(num_shots);
207-
out_dem.append_error_instruction(est_probability, instruction.target_data, /*tag=*/"");
208-
++ei;
209-
break;
210-
}
211-
case stim::DemInstructionType::DEM_DETECTOR: {
212-
out_dem.append_dem_instruction(instruction);
213-
break;
214-
}
215-
case stim::DemInstructionType::DEM_LOGICAL_OBSERVABLE: {
216-
out_dem.append_dem_instruction(instruction);
217-
break;
218-
}
219-
default:
220-
throw std::invalid_argument("Unrecognized instruction type: " + instruction.str());
232+
size_t error_index = 0;
233+
for (const stim::DemInstruction& instruction : flat_dem.instructions) {
234+
if (instruction.type == stim::DemInstructionType::DEM_ERROR) {
235+
double est_probability = double(error_counts.at(error_index)) / double(num_shots);
236+
out_dem.append_error_instruction(est_probability, instruction.target_data, instruction.tag);
237+
++error_index;
238+
} else {
239+
out_dem.append_dem_instruction(instruction);
221240
}
222241
}
223242
return out_dem;

src/common.h

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -73,17 +73,31 @@ struct Error {
7373

7474
// Makes a new (flattened) dem where identical error mechanisms have been
7575
// merged.
76-
stim::DetectorErrorModel merge_indistinguishable_errors(const stim::DetectorErrorModel& dem);
76+
// `error_index_map[old_error_index]` gives the corresponding merged DEM error
77+
// index in the returned DEM.
78+
stim::DetectorErrorModel merge_indistinguishable_errors(const stim::DetectorErrorModel& dem,
79+
std::vector<size_t>& error_index_map);
7780

7881
// Returns a copy of the given error model with any zero-probability DEM_ERROR
7982
// instructions removed.
80-
stim::DetectorErrorModel remove_zero_probability_errors(const stim::DetectorErrorModel& dem);
83+
// `error_index_map[old_error_index]` gives the corresponding retained DEM error
84+
// index in the returned DEM, or `std::numeric_limits<size_t>::max()` if the
85+
// error was removed.
86+
stim::DetectorErrorModel remove_zero_probability_errors(const stim::DetectorErrorModel& dem,
87+
std::vector<size_t>& error_index_map);
88+
89+
// Updates the base_map by chaining it with next_map.
90+
// base_map[i] = next_map[base_map[i]]
91+
void chain_error_maps(std::vector<size_t>& base_map, const std::vector<size_t>& next_map);
92+
93+
// Inverts the error_map to create a mapping from output error indices back to
94+
// the first original error index that maps to it.
95+
std::vector<size_t> invert_error_map(const std::vector<size_t>& error_map,
96+
size_t num_output_errors);
8197

8298
// Makes a new dem where the probabilities of errors are estimated from the
8399
// fraction of shots they were used in.
84-
// Throws std::invalid_argument if `orig_dem` contains zero-probability errors;
85-
// call remove_zero_probability_errors first.
86-
stim::DetectorErrorModel dem_from_counts(stim::DetectorErrorModel& orig_dem,
100+
stim::DetectorErrorModel dem_from_counts(const stim::DetectorErrorModel& orig_dem,
87101
const std::vector<size_t>& error_counts, size_t num_shots);
88102

89103
/// Computes the weight of an edge resulting from merging edges with weight `a' and weight `b',

src/common.pybind.h

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,8 @@ void add_common_module(py::module& root) {
142142
"merge_indistinguishable_errors",
143143
[](py::object dem) {
144144
auto input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
145-
auto res = common::merge_indistinguishable_errors(input_dem);
145+
std::vector<size_t> error_index_map;
146+
auto res = common::merge_indistinguishable_errors(input_dem, error_index_map);
146147
return make_py_object(res, "DetectorErrorModel");
147148
},
148149
py::arg("dem"), R"pbdoc(
@@ -166,9 +167,13 @@ void add_common_module(py::module& root) {
166167
m.def(
167168
"remove_zero_probability_errors",
168169
[](py::object dem) {
169-
return make_py_object(
170-
common::remove_zero_probability_errors(parse_py_object<stim::DetectorErrorModel>(dem)),
171-
"DetectorErrorModel");
170+
return make_py_object(([&]() {
171+
std::vector<size_t> error_index_map;
172+
return common::remove_zero_probability_errors(
173+
parse_py_object<stim::DetectorErrorModel>(dem),
174+
error_index_map);
175+
})(),
176+
"DetectorErrorModel");
172177
},
173178
py::arg("dem"), R"pbdoc(
174179
Removes errors with a probability of 0 from a `stim.DetectorErrorModel`.

src/common.test.cc

Lines changed: 29 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ TEST(common, ErrorsStructFromDemInstruction) {
2626
EXPECT_EQ(ES.symptom.observables, std::vector<int>{0});
2727
}
2828

29-
TEST(common, DemFromCountsRejectsZeroProbabilityErrors) {
29+
TEST(common, DemFromCountsHandlesZeroProbabilityErrors) {
3030
stim::DetectorErrorModel dem(R"DEM(
3131
error(0.1) D0
3232
error(0) D1
@@ -38,20 +38,32 @@ TEST(common, DemFromCountsRejectsZeroProbabilityErrors) {
3838

3939
std::vector<size_t> counts{1, 7, 4};
4040
size_t num_shots = 10;
41-
EXPECT_THROW({ common::dem_from_counts(dem, counts, num_shots); }, std::invalid_argument);
42-
43-
stim::DetectorErrorModel cleaned = common::remove_zero_probability_errors(dem);
44-
stim::DetectorErrorModel out_dem =
45-
common::dem_from_counts(cleaned, std::vector<size_t>{1, 4}, num_shots);
41+
stim::DetectorErrorModel out_dem = common::dem_from_counts(dem, counts, num_shots);
4642

4743
auto flat = out_dem.flattened();
48-
ASSERT_EQ(out_dem.count_errors(), 2);
49-
ASSERT_GE(flat.instructions.size(), 2);
44+
ASSERT_EQ(out_dem.count_errors(), 3);
45+
ASSERT_GE(flat.instructions.size(), 3);
5046

5147
EXPECT_EQ(flat.instructions[0].type, stim::DemInstructionType::DEM_ERROR);
5248
EXPECT_NEAR(flat.instructions[0].arg_data[0], 0.1, 1e-9);
53-
ASSERT_EQ(flat.instructions[1].type, stim::DemInstructionType::DEM_ERROR);
54-
EXPECT_NEAR(flat.instructions[1].arg_data[0], 0.4, 1e-9);
49+
EXPECT_EQ(flat.instructions[1].type, stim::DemInstructionType::DEM_ERROR);
50+
EXPECT_NEAR(flat.instructions[1].arg_data[0], 0.7, 1e-9);
51+
EXPECT_EQ(flat.instructions[2].type, stim::DemInstructionType::DEM_ERROR);
52+
EXPECT_NEAR(flat.instructions[2].arg_data[0], 0.4, 1e-9);
53+
54+
std::vector<size_t> error_index_map;
55+
stim::DetectorErrorModel cleaned = common::remove_zero_probability_errors(dem, error_index_map);
56+
stim::DetectorErrorModel out_dem_cleaned =
57+
common::dem_from_counts(cleaned, std::vector<size_t>{1, 4}, num_shots);
58+
59+
auto flat_cleaned = out_dem_cleaned.flattened();
60+
ASSERT_EQ(out_dem_cleaned.count_errors(), 2);
61+
ASSERT_GE(flat_cleaned.instructions.size(), 2);
62+
63+
EXPECT_EQ(flat_cleaned.instructions[0].type, stim::DemInstructionType::DEM_ERROR);
64+
EXPECT_NEAR(flat_cleaned.instructions[0].arg_data[0], 0.1, 1e-9);
65+
ASSERT_EQ(flat_cleaned.instructions[1].type, stim::DemInstructionType::DEM_ERROR);
66+
EXPECT_NEAR(flat_cleaned.instructions[1].arg_data[0], 0.4, 1e-9);
5567
}
5668

5769
TEST(common, DemFromCountsSimpleTwoErrors) {
@@ -86,7 +98,8 @@ TEST(common, RemoveZeroProbabilityErrors) {
8698
detector(0, 0, 0) D2
8799
)DEM");
88100

89-
stim::DetectorErrorModel cleaned = common::remove_zero_probability_errors(dem);
101+
std::vector<size_t> error_index_map;
102+
stim::DetectorErrorModel cleaned = common::remove_zero_probability_errors(dem, error_index_map);
90103

91104
EXPECT_EQ(cleaned.count_errors(), 2);
92105
auto flat = cleaned.flattened();
@@ -153,30 +166,31 @@ TEST(CommonTest, merge_indistinguishable_errors_two_errors) {
153166
double p2 = 0.2;
154167
double expected_merged_p = p1 * (1 - p2) + p2 * (1 - p1);
155168
auto dem1 = create_dem_with_two_errors(p1, p2);
156-
auto merged_dem1 = common::merge_indistinguishable_errors(dem1);
169+
std::vector<size_t> error_index_map;
170+
auto merged_dem1 = common::merge_indistinguishable_errors(dem1, error_index_map);
157171
ASSERT_NEAR(get_merged_probability(merged_dem1), expected_merged_p, 1e-9);
158172

159173
// Case 2: One low, one high probability.
160174
p1 = 0.1;
161175
p2 = 0.8;
162176
expected_merged_p = p1 * (1 - p2) + p2 * (1 - p1);
163177
auto dem2 = create_dem_with_two_errors(p1, p2);
164-
auto merged_dem2 = common::merge_indistinguishable_errors(dem2);
178+
auto merged_dem2 = common::merge_indistinguishable_errors(dem2, error_index_map);
165179
ASSERT_NEAR(get_merged_probability(merged_dem2), expected_merged_p, 1e-9);
166180

167181
// Case 3: One high, one low probability.
168182
p1 = 0.8;
169183
p2 = 0.1;
170184
expected_merged_p = p1 * (1 - p2) + p2 * (1 - p1);
171185
auto dem3 = create_dem_with_two_errors(p1, p2);
172-
auto merged_dem3 = common::merge_indistinguishable_errors(dem3);
186+
auto merged_dem3 = common::merge_indistinguishable_errors(dem3, error_index_map);
173187
ASSERT_NEAR(get_merged_probability(merged_dem3), expected_merged_p, 1e-9);
174188

175189
// Case 4: Both probabilities are high.
176190
p1 = 0.8;
177191
p2 = 0.9;
178192
expected_merged_p = p1 * (1 - p2) + p2 * (1 - p1);
179193
auto dem4 = create_dem_with_two_errors(p1, p2);
180-
auto merged_dem4 = common::merge_indistinguishable_errors(dem4);
194+
auto merged_dem4 = common::merge_indistinguishable_errors(dem4, error_index_map);
181195
ASSERT_NEAR(get_merged_probability(merged_dem4), expected_merged_p, 1e-9);
182196
}

0 commit comments

Comments
 (0)