Skip to content

Commit 27cffb5

Browse files
committed
fix: resolve C++ multi-pass reindexing, seed alignment, and degeneracy mapping bugs
TAG=agy
1 parent e165045 commit 27cffb5

4 files changed

Lines changed: 515 additions & 373 deletions

File tree

src/multi_pass_tesseract_decoder.cc

Lines changed: 128 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,36 @@ MultiPassTesseractDecoder::MultiPassTesseractDecoder(
2525
initialize(dem, classifier);
2626
}
2727

28+
void MultiPassTesseractDecoder::validate_annotations(const stim::DetectorErrorModel& dem,
29+
const DetectorClassifier& classifier) {
30+
stim::DetectorErrorModel flattened = dem.flattened();
31+
size_t total_global_detectors = (size_t)flattened.count_detectors();
32+
33+
std::set<uint64_t> all_ids;
34+
std::map<uint64_t, std::string> tags;
35+
for (const auto& inst : flattened.instructions) {
36+
if (inst.type == stim::DemInstructionType::DEM_DETECTOR) {
37+
uint64_t d = inst.target_data[0].val();
38+
all_ids.insert(d);
39+
tags[d] = inst.tag;
40+
}
41+
}
42+
auto coords_map = flattened.get_detector_coordinates(all_ids);
43+
44+
std::set<int> unique_classes;
45+
for (size_t i = 0; i < total_global_detectors; ++i) {
46+
std::vector<double> c = coords_map.count(i) ? coords_map.at(i) : std::vector<double>{};
47+
std::string t = tags.count(i) ? tags.at(i) : "";
48+
int cls = classifier((int)i, c, t);
49+
if (cls != -1) unique_classes.insert(cls);
50+
}
51+
if (unique_classes.size() < 2) {
52+
throw std::invalid_argument(
53+
"Multi-pass decoding requires an annotated Stim circuit/DEM with at least "
54+
"2 stabilizer components.");
55+
}
56+
}
57+
2858
void MultiPassTesseractDecoder::initialize(const stim::DetectorErrorModel& dem,
2959
const DetectorClassifier& classifier) {
3060
stim::DetectorErrorModel flattened = dem.flattened();
@@ -63,11 +93,6 @@ void MultiPassTesseractDecoder::initialize(const stim::DetectorErrorModel& dem,
6393
for (int c : unique_classes) class_to_comp_id[c] = next_comp_id++;
6494

6595
size_t num_components = unique_classes.size();
66-
if (num_components < 2) {
67-
throw std::invalid_argument(
68-
"Multi-pass decoding requires an annotated Stim circuit/DEM with at least 2 stabilizer "
69-
"components.");
70-
}
7196
component_decoders.resize(num_components);
7297

7398
global_det_to_comp_id.assign(total_global_detectors, -1);
@@ -93,71 +118,48 @@ void MultiPassTesseractDecoder::initialize(const stim::DetectorErrorModel& dem,
93118
for (size_t i = 0; i < component_decoders.size(); ++i) {
94119
auto& cd = component_decoders[i];
95120

96-
std::vector<int> sorted_global_dets(cd.component_detectors.begin(),
97-
cd.component_detectors.end());
98-
std::sort(sorted_global_dets.begin(), sorted_global_dets.end());
99-
for (size_t local_idx = 0; local_idx < sorted_global_dets.size(); ++local_idx) {
100-
cd.global_to_local_det[sorted_global_dets[local_idx]] = (int)local_idx;
121+
for (size_t global_d = 0; global_d < total_global_detectors; ++global_d) {
122+
cd.global_to_local_det[global_d] = (int)global_d;
101123
}
102124

103125
stim::DetectorErrorModel local_dem;
104-
// MUST append detector instructions for ALL local detectors first to set count_detectors()
105-
// correctly
106-
for (size_t local_idx = 0; local_idx < sorted_global_dets.size(); ++local_idx) {
107-
int global_d = sorted_global_dets[local_idx];
126+
for (size_t global_d = 0; global_d < total_global_detectors; ++global_d) {
108127
std::vector<double> c =
109128
coords_map.count(global_d) ? coords_map.at(global_d) : std::vector<double>{};
110129
std::string t = tags.count(global_d) ? tags.at(global_d) : "";
111-
local_dem.append_detector_instruction(c, stim::DemTarget::relative_detector_id(local_idx), t);
130+
local_dem.append_detector_instruction(c, stim::DemTarget::relative_detector_id(global_d), t);
112131
}
113132

114133
for (const auto& inst : component_dems_raw[i].instructions) {
115134
if (inst.type == stim::DemInstructionType::DEM_ERROR) {
116-
std::vector<stim::DemTarget> local_targets;
117135
bool has_obs = false;
118136
for (const auto& t : inst.target_data) {
119-
if (t.is_relative_detector_id()) {
120-
int global_d = t.val();
121-
local_targets.push_back(
122-
stim::DemTarget::relative_detector_id(cd.global_to_local_det.at(global_d)));
123-
} else {
124-
local_targets.push_back(t);
125-
if (t.is_observable_id()) has_obs = true;
126-
}
137+
if (t.is_observable_id()) has_obs = true;
127138
}
128139
if (has_obs) cd.affects_observable = true;
129-
local_dem.append_error_instruction(inst.arg_data[0], local_targets, inst.tag);
140+
local_dem.append_error_instruction(inst.arg_data[0], inst.target_data, inst.tag);
130141
} else if (inst.type == stim::DemInstructionType::DEM_LOGICAL_OBSERVABLE) {
131142
local_dem.append_dem_instruction(inst);
132143
}
133144
}
134145

135-
// std::cout << "DEBUG: local_dem " << i << " : " << local_dem << std::endl;
136-
137146
TesseractConfig config = base_config;
138147
config.dem = local_dem;
139148
config.merge_errors = true;
140-
config.det_orders = build_det_orders(config.dem, num_det_orders, det_order_method, seed + i);
149+
config.det_orders = build_det_orders(config.dem, num_det_orders, det_order_method, seed);
141150

142151
cd.decoder = std::make_unique<TesseractDecoder>(config);
143-
// std::cout << "DEBUG: Component " << i << " initialized with " << cd.decoder->errors.size() <<
144-
// " errors and " << config.dem.count_detectors() << " detectors." << std::endl;
145-
/*
146-
for (size_t ei = 0; ei < cd.decoder->errors.size(); ei++) {
147-
// std::cout << " Comp " << i << " Err " << ei << ": D";
148-
for (int d : cd.decoder->errors[ei].symptom.detectors) // std::cout << d << " ";
149-
// std::cout << std::endl;
152+
if (base_config.verbose) {
153+
std::cout << "DEBUG: Component " << i << " initialized with " << cd.decoder->errors.size()
154+
<< " errors and " << config.dem.count_detectors() << " detectors." << std::endl;
150155
}
151-
*/
152156
cd.error_index_to_rules.resize(cd.decoder->errors.size());
153157

154158
for (size_t ei = 0; ei < cd.decoder->errors.size(); ++ei) {
155159
cd.original_costs.push_back(cd.decoder->errors[ei].likelihood_cost);
156-
Hyperedge local_symptom = cd.decoder->errors[ei].symptom.detectors;
157-
Hyperedge global_symptom;
158-
for (int local_d : local_symptom) global_symptom.push_back(sorted_global_dets[local_d]);
160+
Hyperedge global_symptom = cd.decoder->errors[ei].symptom.detectors;
159161
std::sort(global_symptom.begin(), global_symptom.end());
160-
cd.symptom_to_error_index[global_symptom] = ei;
162+
cd.symptom_to_error_index[global_symptom].push_back(ei);
161163
}
162164
}
163165

@@ -167,19 +169,27 @@ void MultiPassTesseractDecoder::initialize(const stim::DetectorErrorModel& dem,
167169
int causal_comp = -1;
168170
if (!causal_symptom.empty()) causal_comp = global_det_to_comp_id[causal_symptom[0]];
169171
if (causal_comp == -1) continue;
172+
170173
auto it = component_decoders[causal_comp].symptom_to_error_index.find(causal_symptom);
171174
if (it == component_decoders[causal_comp].symptom_to_error_index.end()) continue;
172-
size_t causal_err_idx = it->second;
173-
for (const auto& imp : implied_probs) {
174-
Hyperedge target_symptom = imp.affected_hyperedge;
175-
std::sort(target_symptom.begin(), target_symptom.end());
176-
int target_comp = -1;
177-
if (!target_symptom.empty()) target_comp = global_det_to_comp_id[target_symptom[0]];
178-
if (target_comp == -1) continue;
179-
auto t_it = component_decoders[target_comp].symptom_to_error_index.find(target_symptom);
180-
if (t_it != component_decoders[target_comp].symptom_to_error_index.end()) {
181-
component_decoders[causal_comp].error_index_to_rules[causal_err_idx].push_back(
182-
{(size_t)target_comp, t_it->second, imp.probability});
175+
176+
// Loop through all degenerate causal error indices!
177+
for (size_t causal_err_idx : it->second) {
178+
for (const auto& imp : implied_probs) {
179+
Hyperedge target_symptom = imp.affected_hyperedge;
180+
std::sort(target_symptom.begin(), target_symptom.end());
181+
int target_comp = -1;
182+
if (!target_symptom.empty()) target_comp = global_det_to_comp_id[target_symptom[0]];
183+
if (target_comp == -1) continue;
184+
185+
auto t_it = component_decoders[target_comp].symptom_to_error_index.find(target_symptom);
186+
if (t_it != component_decoders[target_comp].symptom_to_error_index.end()) {
187+
// Loop through all degenerate target error indices and add rules to each!
188+
for (size_t target_err_idx : t_it->second) {
189+
component_decoders[causal_comp].error_index_to_rules[causal_err_idx].push_back(
190+
{(size_t)target_comp, target_err_idx, imp.probability});
191+
}
192+
}
183193
}
184194
}
185195
}
@@ -240,71 +250,99 @@ void MultiPassTesseractDecoder::build_causal_schedule() {
240250

241251
std::vector<int> MultiPassTesseractDecoder::decode(const std::vector<uint64_t>& detections) {
242252
last_shot_num_reweights = 0;
243-
// 1. Multi-Pass Loop: Earlier passes only bias the final pass.
253+
254+
// 1. Multi-Pass Loop: Sequentially schedules component passes and propagates priors.
244255
for (size_t pass = 0; pass < num_passes; ++pass) {
245256
bool is_final_pass = (pass == num_passes - 1);
246257

258+
// Decode scheduled components for the current pass layer using persistent local buffers.
247259
for (size_t comp_idx : pass_schedule[pass]) {
248260
auto& cd = component_decoders[comp_idx];
249261
std::vector<uint64_t> local_dets;
250262
for (uint64_t d : detections) {
251-
if (cd.global_to_local_det.count((int)d)) {
252-
local_dets.push_back((uint64_t)cd.global_to_local_det.at((int)d));
263+
if (cd.component_detectors.count((int)d)) {
264+
local_dets.push_back(d);
253265
}
254266
}
255267

256-
// Perform decoding for this component in this pass.
257268
cd.decoder->decode_to_errors(local_dets);
269+
component_predictions[comp_idx] = cd.decoder->predicted_errors_buffer;
270+
}
258271

259-
if (is_final_pass) {
260-
// Track components that decode in the final pass for extraction.
261-
final_pass_active_components.push_back(comp_idx);
262-
} else {
263-
// If this is NOT the final pass, use the results for reweighting, then discard them.
272+
if (!is_final_pass) {
273+
// Step A: Apply Damped Fractional Memory to previously modified priors.
274+
// Smoothly decay current modifications back toward the baseline to prevent message
275+
// saturation.
276+
double gamma =
277+
0.5; // Tunable decay factor: 1.0 is strict isolation, 0.0 is full accumulation.
278+
279+
for (size_t m_comp_idx : modified_component_indices) {
280+
auto& cd = component_decoders[m_comp_idx];
281+
if (!cd.shot_all_modified_error_indices.empty()) {
282+
for (size_t idx : cd.shot_all_modified_error_indices) {
283+
double baseline_cost = cd.original_costs[idx];
284+
double current_cost = cd.decoder->errors[idx].likelihood_cost;
285+
cd.decoder->errors[idx].likelihood_cost =
286+
gamma * baseline_cost + (1.0 - gamma) * current_cost;
287+
}
288+
cd.decoder->update_internal_costs(cd.shot_all_modified_error_indices);
289+
// Retain tracking indices so the final Surgical Reset completely clears cross-shot state.
290+
}
291+
}
292+
293+
// Step B: Broadcast reweighting rules derived strictly from the latest predictions.
294+
for (size_t comp_idx : pass_schedule[pass]) {
295+
auto& cd = component_decoders[comp_idx];
264296
for (size_t dem_err_idx : cd.decoder->predicted_errors_buffer) {
265297
size_t internal_err_idx = cd.decoder->dem_error_to_error.at(dem_err_idx);
266298
if (internal_err_idx == std::numeric_limits<size_t>::max()) continue;
267299

268300
for (const auto& rule : cd.error_index_to_rules[internal_err_idx]) {
269301
auto& target_cd = component_decoders[rule.target_comp_idx];
270302

271-
// Track modified components only once per shot.
272-
if (target_cd.modified_error_indices.empty()) {
273-
modified_component_indices.push_back(rule.target_comp_idx);
274-
}
303+
modified_component_indices.push_back(rule.target_comp_idx);
275304

276-
// Cap probability at 0.499 to prevent negative costs in the engine.
277-
target_cd.decoder->errors[rule.target_error_idx].set_with_probability(
278-
std::min(rule.conditional_prob, 0.499));
279-
target_cd.modified_error_indices.push_back(rule.target_error_idx);
280-
last_shot_num_reweights++;
305+
// Apply Max-Prob Rule safely for concurrent rules within this pass layer.
306+
double current_p = target_cd.decoder->errors[rule.target_error_idx].get_probability();
307+
if (rule.conditional_prob > current_p) {
308+
target_cd.decoder->errors[rule.target_error_idx].set_with_probability(
309+
std::min(rule.conditional_prob, 0.5));
310+
target_cd.shot_all_modified_error_indices.push_back(rule.target_error_idx);
311+
last_shot_num_reweights++;
312+
}
281313
}
282314
}
283-
// Clear the buffer so these intermediate decisions don't contribute to the final
284-
// prediction.
285-
cd.decoder->predicted_errors_buffer.clear();
286315
}
287-
}
288316

289-
// Sync modified costs for the next pass.
290-
if (!is_final_pass) {
317+
// Step C: Deduplicate modified tracking vectors and synchronize internal graph costs.
318+
std::sort(modified_component_indices.begin(), modified_component_indices.end());
319+
modified_component_indices.erase(
320+
std::unique(modified_component_indices.begin(), modified_component_indices.end()),
321+
modified_component_indices.end());
322+
291323
for (size_t m_comp_idx : modified_component_indices) {
292324
auto& cd = component_decoders[m_comp_idx];
293-
if (!cd.modified_error_indices.empty()) {
294-
cd.decoder->update_internal_costs(cd.modified_error_indices);
325+
if (!cd.shot_all_modified_error_indices.empty()) {
326+
std::sort(cd.shot_all_modified_error_indices.begin(),
327+
cd.shot_all_modified_error_indices.end());
328+
cd.shot_all_modified_error_indices.erase(
329+
std::unique(cd.shot_all_modified_error_indices.begin(),
330+
cd.shot_all_modified_error_indices.end()),
331+
cd.shot_all_modified_error_indices.end());
332+
cd.decoder->update_internal_costs(cd.shot_all_modified_error_indices);
295333
}
296334
}
297335
}
298336
}
299337

300-
// 2. Unified Logical Extraction: Collect final-pass predictions from only active components.
338+
// 2. Unified Logical Extraction: Collect final predictions from ALL components that ran during
339+
// the shot.
301340
std::set<int> flipped_observables;
302-
for (size_t comp_idx : final_pass_active_components) {
341+
for (const auto& [comp_idx, preds] : component_predictions) {
303342
auto& cd = component_decoders[comp_idx];
304-
if (cd.decoder->predicted_errors_buffer.empty()) continue;
343+
if (preds.empty()) continue;
305344

306-
std::vector<int> local_flips =
307-
cd.decoder->get_flipped_observables(cd.decoder->predicted_errors_buffer);
345+
std::vector<int> local_flips = cd.decoder->get_flipped_observables(preds);
308346
for (int obs : local_flips) {
309347
if (flipped_observables.count(obs))
310348
flipped_observables.erase(obs);
@@ -313,17 +351,19 @@ std::vector<int> MultiPassTesseractDecoder::decode(const std::vector<uint64_t>&
313351
}
314352
}
315353

316-
// 3. Surgical Reset: Restore modified costs for the next shot.
354+
// 3. Surgical Reset: Restore modified costs to leave the internal structures pristine for the
355+
// next shot.
317356
for (size_t m_comp_idx : modified_component_indices) {
318357
auto& cd = component_decoders[m_comp_idx];
319-
for (size_t idx : cd.modified_error_indices) {
320-
cd.decoder->errors[idx].likelihood_cost = cd.original_costs[idx];
358+
if (!cd.shot_all_modified_error_indices.empty()) {
359+
for (size_t idx : cd.shot_all_modified_error_indices) {
360+
cd.decoder->errors[idx].likelihood_cost = cd.original_costs[idx];
361+
}
362+
cd.decoder->update_internal_costs(cd.shot_all_modified_error_indices);
363+
cd.shot_all_modified_error_indices.clear();
321364
}
322-
cd.decoder->update_internal_costs(cd.modified_error_indices);
323-
cd.modified_error_indices.clear();
324365
}
325366

326-
// Clear shot-level tracking vectors.
327367
modified_component_indices.clear();
328368
final_pass_active_components.clear();
329369

0 commit comments

Comments
 (0)