@@ -25,6 +25,36 @@ MultiPassTesseractDecoder::MultiPassTesseractDecoder(
2525 initialize (dem, classifier);
2626}
2727
28+ void MultiPassTesseractDecoder::validate_annotations (const stim::DetectorErrorModel& dem,
29+ const DetectorClassifier& classifier) {
30+ stim::DetectorErrorModel flattened = dem.flattened ();
31+ size_t total_global_detectors = (size_t )flattened.count_detectors ();
32+
33+ std::set<uint64_t > all_ids;
34+ std::map<uint64_t , std::string> tags;
35+ for (const auto & inst : flattened.instructions ) {
36+ if (inst.type == stim::DemInstructionType::DEM_DETECTOR ) {
37+ uint64_t d = inst.target_data [0 ].val ();
38+ all_ids.insert (d);
39+ tags[d] = inst.tag ;
40+ }
41+ }
42+ auto coords_map = flattened.get_detector_coordinates (all_ids);
43+
44+ std::set<int > unique_classes;
45+ for (size_t i = 0 ; i < total_global_detectors; ++i) {
46+ std::vector<double > c = coords_map.count (i) ? coords_map.at (i) : std::vector<double >{};
47+ std::string t = tags.count (i) ? tags.at (i) : " " ;
48+ int cls = classifier ((int )i, c, t);
49+ if (cls != -1 ) unique_classes.insert (cls);
50+ }
51+ if (unique_classes.size () < 2 ) {
52+ throw std::invalid_argument (
53+ " Multi-pass decoding requires an annotated Stim circuit/DEM with at least "
54+ " 2 stabilizer components." );
55+ }
56+ }
57+
2858void MultiPassTesseractDecoder::initialize (const stim::DetectorErrorModel& dem,
2959 const DetectorClassifier& classifier) {
3060 stim::DetectorErrorModel flattened = dem.flattened ();
@@ -63,11 +93,6 @@ void MultiPassTesseractDecoder::initialize(const stim::DetectorErrorModel& dem,
6393 for (int c : unique_classes) class_to_comp_id[c] = next_comp_id++;
6494
6595 size_t num_components = unique_classes.size ();
66- if (num_components < 2 ) {
67- throw std::invalid_argument (
68- " Multi-pass decoding requires an annotated Stim circuit/DEM with at least 2 stabilizer "
69- " components." );
70- }
7196 component_decoders.resize (num_components);
7297
7398 global_det_to_comp_id.assign (total_global_detectors, -1 );
@@ -93,71 +118,48 @@ void MultiPassTesseractDecoder::initialize(const stim::DetectorErrorModel& dem,
93118 for (size_t i = 0 ; i < component_decoders.size (); ++i) {
94119 auto & cd = component_decoders[i];
95120
96- std::vector<int > sorted_global_dets (cd.component_detectors .begin (),
97- cd.component_detectors .end ());
98- std::sort (sorted_global_dets.begin (), sorted_global_dets.end ());
99- for (size_t local_idx = 0 ; local_idx < sorted_global_dets.size (); ++local_idx) {
100- cd.global_to_local_det [sorted_global_dets[local_idx]] = (int )local_idx;
121+ for (size_t global_d = 0 ; global_d < total_global_detectors; ++global_d) {
122+ cd.global_to_local_det [global_d] = (int )global_d;
101123 }
102124
103125 stim::DetectorErrorModel local_dem;
104- // MUST append detector instructions for ALL local detectors first to set count_detectors()
105- // correctly
106- for (size_t local_idx = 0 ; local_idx < sorted_global_dets.size (); ++local_idx) {
107- int global_d = sorted_global_dets[local_idx];
126+ for (size_t global_d = 0 ; global_d < total_global_detectors; ++global_d) {
108127 std::vector<double > c =
109128 coords_map.count (global_d) ? coords_map.at (global_d) : std::vector<double >{};
110129 std::string t = tags.count (global_d) ? tags.at (global_d) : " " ;
111- local_dem.append_detector_instruction (c, stim::DemTarget::relative_detector_id (local_idx ), t);
130+ local_dem.append_detector_instruction (c, stim::DemTarget::relative_detector_id (global_d ), t);
112131 }
113132
114133 for (const auto & inst : component_dems_raw[i].instructions ) {
115134 if (inst.type == stim::DemInstructionType::DEM_ERROR ) {
116- std::vector<stim::DemTarget> local_targets;
117135 bool has_obs = false ;
118136 for (const auto & t : inst.target_data ) {
119- if (t.is_relative_detector_id ()) {
120- int global_d = t.val ();
121- local_targets.push_back (
122- stim::DemTarget::relative_detector_id (cd.global_to_local_det .at (global_d)));
123- } else {
124- local_targets.push_back (t);
125- if (t.is_observable_id ()) has_obs = true ;
126- }
137+ if (t.is_observable_id ()) has_obs = true ;
127138 }
128139 if (has_obs) cd.affects_observable = true ;
129- local_dem.append_error_instruction (inst.arg_data [0 ], local_targets , inst.tag );
140+ local_dem.append_error_instruction (inst.arg_data [0 ], inst. target_data , inst.tag );
130141 } else if (inst.type == stim::DemInstructionType::DEM_LOGICAL_OBSERVABLE ) {
131142 local_dem.append_dem_instruction (inst);
132143 }
133144 }
134145
135- // std::cout << "DEBUG: local_dem " << i << " : " << local_dem << std::endl;
136-
137146 TesseractConfig config = base_config;
138147 config.dem = local_dem;
139148 config.merge_errors = true ;
140- config.det_orders = build_det_orders (config.dem , num_det_orders, det_order_method, seed + i );
149+ config.det_orders = build_det_orders (config.dem , num_det_orders, det_order_method, seed);
141150
142151 cd.decoder = std::make_unique<TesseractDecoder>(config);
143- // std::cout << "DEBUG: Component " << i << " initialized with " << cd.decoder->errors.size() <<
144- // " errors and " << config.dem.count_detectors() << " detectors." << std::endl;
145- /*
146- for (size_t ei = 0; ei < cd.decoder->errors.size(); ei++) {
147- // std::cout << " Comp " << i << " Err " << ei << ": D";
148- for (int d : cd.decoder->errors[ei].symptom.detectors) // std::cout << d << " ";
149- // std::cout << std::endl;
152+ if (base_config.verbose ) {
153+ std::cout << " DEBUG: Component " << i << " initialized with " << cd.decoder ->errors .size ()
154+ << " errors and " << config.dem .count_detectors () << " detectors." << std::endl;
150155 }
151- */
152156 cd.error_index_to_rules .resize (cd.decoder ->errors .size ());
153157
154158 for (size_t ei = 0 ; ei < cd.decoder ->errors .size (); ++ei) {
155159 cd.original_costs .push_back (cd.decoder ->errors [ei].likelihood_cost );
156- Hyperedge local_symptom = cd.decoder ->errors [ei].symptom .detectors ;
157- Hyperedge global_symptom;
158- for (int local_d : local_symptom) global_symptom.push_back (sorted_global_dets[local_d]);
160+ Hyperedge global_symptom = cd.decoder ->errors [ei].symptom .detectors ;
159161 std::sort (global_symptom.begin (), global_symptom.end ());
160- cd.symptom_to_error_index [global_symptom] = ei ;
162+ cd.symptom_to_error_index [global_symptom]. push_back (ei) ;
161163 }
162164 }
163165
@@ -167,19 +169,27 @@ void MultiPassTesseractDecoder::initialize(const stim::DetectorErrorModel& dem,
167169 int causal_comp = -1 ;
168170 if (!causal_symptom.empty ()) causal_comp = global_det_to_comp_id[causal_symptom[0 ]];
169171 if (causal_comp == -1 ) continue ;
172+
170173 auto it = component_decoders[causal_comp].symptom_to_error_index .find (causal_symptom);
171174 if (it == component_decoders[causal_comp].symptom_to_error_index .end ()) continue ;
172- size_t causal_err_idx = it->second ;
173- for (const auto & imp : implied_probs) {
174- Hyperedge target_symptom = imp.affected_hyperedge ;
175- std::sort (target_symptom.begin (), target_symptom.end ());
176- int target_comp = -1 ;
177- if (!target_symptom.empty ()) target_comp = global_det_to_comp_id[target_symptom[0 ]];
178- if (target_comp == -1 ) continue ;
179- auto t_it = component_decoders[target_comp].symptom_to_error_index .find (target_symptom);
180- if (t_it != component_decoders[target_comp].symptom_to_error_index .end ()) {
181- component_decoders[causal_comp].error_index_to_rules [causal_err_idx].push_back (
182- {(size_t )target_comp, t_it->second , imp.probability });
175+
176+ // Loop through all degenerate causal error indices!
177+ for (size_t causal_err_idx : it->second ) {
178+ for (const auto & imp : implied_probs) {
179+ Hyperedge target_symptom = imp.affected_hyperedge ;
180+ std::sort (target_symptom.begin (), target_symptom.end ());
181+ int target_comp = -1 ;
182+ if (!target_symptom.empty ()) target_comp = global_det_to_comp_id[target_symptom[0 ]];
183+ if (target_comp == -1 ) continue ;
184+
185+ auto t_it = component_decoders[target_comp].symptom_to_error_index .find (target_symptom);
186+ if (t_it != component_decoders[target_comp].symptom_to_error_index .end ()) {
187+ // Loop through all degenerate target error indices and add rules to each!
188+ for (size_t target_err_idx : t_it->second ) {
189+ component_decoders[causal_comp].error_index_to_rules [causal_err_idx].push_back (
190+ {(size_t )target_comp, target_err_idx, imp.probability });
191+ }
192+ }
183193 }
184194 }
185195 }
@@ -240,71 +250,99 @@ void MultiPassTesseractDecoder::build_causal_schedule() {
240250
241251std::vector<int > MultiPassTesseractDecoder::decode (const std::vector<uint64_t >& detections) {
242252 last_shot_num_reweights = 0 ;
243- // 1. Multi-Pass Loop: Earlier passes only bias the final pass.
253+
254+ // 1. Multi-Pass Loop: Sequentially schedules component passes and propagates priors.
244255 for (size_t pass = 0 ; pass < num_passes; ++pass) {
245256 bool is_final_pass = (pass == num_passes - 1 );
246257
258+ // Decode scheduled components for the current pass layer using persistent local buffers.
247259 for (size_t comp_idx : pass_schedule[pass]) {
248260 auto & cd = component_decoders[comp_idx];
249261 std::vector<uint64_t > local_dets;
250262 for (uint64_t d : detections) {
251- if (cd.global_to_local_det .count ((int )d)) {
252- local_dets.push_back (( uint64_t )cd. global_to_local_det . at (( int )d) );
263+ if (cd.component_detectors .count ((int )d)) {
264+ local_dets.push_back (d );
253265 }
254266 }
255267
256- // Perform decoding for this component in this pass.
257268 cd.decoder ->decode_to_errors (local_dets);
269+ component_predictions[comp_idx] = cd.decoder ->predicted_errors_buffer ;
270+ }
258271
259- if (is_final_pass) {
260- // Track components that decode in the final pass for extraction.
261- final_pass_active_components.push_back (comp_idx);
262- } else {
263- // If this is NOT the final pass, use the results for reweighting, then discard them.
272+ if (!is_final_pass) {
273+ // Step A: Apply Damped Fractional Memory to previously modified priors.
274+ // Smoothly decay current modifications back toward the baseline to prevent message
275+ // saturation.
276+ double gamma =
277+ 0.5 ; // Tunable decay factor: 1.0 is strict isolation, 0.0 is full accumulation.
278+
279+ for (size_t m_comp_idx : modified_component_indices) {
280+ auto & cd = component_decoders[m_comp_idx];
281+ if (!cd.shot_all_modified_error_indices .empty ()) {
282+ for (size_t idx : cd.shot_all_modified_error_indices ) {
283+ double baseline_cost = cd.original_costs [idx];
284+ double current_cost = cd.decoder ->errors [idx].likelihood_cost ;
285+ cd.decoder ->errors [idx].likelihood_cost =
286+ gamma * baseline_cost + (1.0 - gamma) * current_cost;
287+ }
288+ cd.decoder ->update_internal_costs (cd.shot_all_modified_error_indices );
289+ // Retain tracking indices so the final Surgical Reset completely clears cross-shot state.
290+ }
291+ }
292+
293+ // Step B: Broadcast reweighting rules derived strictly from the latest predictions.
294+ for (size_t comp_idx : pass_schedule[pass]) {
295+ auto & cd = component_decoders[comp_idx];
264296 for (size_t dem_err_idx : cd.decoder ->predicted_errors_buffer ) {
265297 size_t internal_err_idx = cd.decoder ->dem_error_to_error .at (dem_err_idx);
266298 if (internal_err_idx == std::numeric_limits<size_t >::max ()) continue ;
267299
268300 for (const auto & rule : cd.error_index_to_rules [internal_err_idx]) {
269301 auto & target_cd = component_decoders[rule.target_comp_idx ];
270302
271- // Track modified components only once per shot.
272- if (target_cd.modified_error_indices .empty ()) {
273- modified_component_indices.push_back (rule.target_comp_idx );
274- }
303+ modified_component_indices.push_back (rule.target_comp_idx );
275304
276- // Cap probability at 0.499 to prevent negative costs in the engine.
277- target_cd.decoder ->errors [rule.target_error_idx ].set_with_probability (
278- std::min (rule.conditional_prob , 0.499 ));
279- target_cd.modified_error_indices .push_back (rule.target_error_idx );
280- last_shot_num_reweights++;
305+ // Apply Max-Prob Rule safely for concurrent rules within this pass layer.
306+ double current_p = target_cd.decoder ->errors [rule.target_error_idx ].get_probability ();
307+ if (rule.conditional_prob > current_p) {
308+ target_cd.decoder ->errors [rule.target_error_idx ].set_with_probability (
309+ std::min (rule.conditional_prob , 0.5 ));
310+ target_cd.shot_all_modified_error_indices .push_back (rule.target_error_idx );
311+ last_shot_num_reweights++;
312+ }
281313 }
282314 }
283- // Clear the buffer so these intermediate decisions don't contribute to the final
284- // prediction.
285- cd.decoder ->predicted_errors_buffer .clear ();
286315 }
287- }
288316
289- // Sync modified costs for the next pass.
290- if (!is_final_pass) {
317+ // Step C: Deduplicate modified tracking vectors and synchronize internal graph costs.
318+ std::sort (modified_component_indices.begin (), modified_component_indices.end ());
319+ modified_component_indices.erase (
320+ std::unique (modified_component_indices.begin (), modified_component_indices.end ()),
321+ modified_component_indices.end ());
322+
291323 for (size_t m_comp_idx : modified_component_indices) {
292324 auto & cd = component_decoders[m_comp_idx];
293- if (!cd.modified_error_indices .empty ()) {
294- cd.decoder ->update_internal_costs (cd.modified_error_indices );
325+ if (!cd.shot_all_modified_error_indices .empty ()) {
326+ std::sort (cd.shot_all_modified_error_indices .begin (),
327+ cd.shot_all_modified_error_indices .end ());
328+ cd.shot_all_modified_error_indices .erase (
329+ std::unique (cd.shot_all_modified_error_indices .begin (),
330+ cd.shot_all_modified_error_indices .end ()),
331+ cd.shot_all_modified_error_indices .end ());
332+ cd.decoder ->update_internal_costs (cd.shot_all_modified_error_indices );
295333 }
296334 }
297335 }
298336 }
299337
300- // 2. Unified Logical Extraction: Collect final-pass predictions from only active components.
338+ // 2. Unified Logical Extraction: Collect final predictions from ALL components that ran during
339+ // the shot.
301340 std::set<int > flipped_observables;
302- for (size_t comp_idx : final_pass_active_components ) {
341+ for (const auto & [ comp_idx, preds] : component_predictions ) {
303342 auto & cd = component_decoders[comp_idx];
304- if (cd. decoder -> predicted_errors_buffer .empty ()) continue ;
343+ if (preds .empty ()) continue ;
305344
306- std::vector<int > local_flips =
307- cd.decoder ->get_flipped_observables (cd.decoder ->predicted_errors_buffer );
345+ std::vector<int > local_flips = cd.decoder ->get_flipped_observables (preds);
308346 for (int obs : local_flips) {
309347 if (flipped_observables.count (obs))
310348 flipped_observables.erase (obs);
@@ -313,17 +351,19 @@ std::vector<int> MultiPassTesseractDecoder::decode(const std::vector<uint64_t>&
313351 }
314352 }
315353
316- // 3. Surgical Reset: Restore modified costs for the next shot.
354+ // 3. Surgical Reset: Restore modified costs to leave the internal structures pristine for the
355+ // next shot.
317356 for (size_t m_comp_idx : modified_component_indices) {
318357 auto & cd = component_decoders[m_comp_idx];
319- for (size_t idx : cd.modified_error_indices ) {
320- cd.decoder ->errors [idx].likelihood_cost = cd.original_costs [idx];
358+ if (!cd.shot_all_modified_error_indices .empty ()) {
359+ for (size_t idx : cd.shot_all_modified_error_indices ) {
360+ cd.decoder ->errors [idx].likelihood_cost = cd.original_costs [idx];
361+ }
362+ cd.decoder ->update_internal_costs (cd.shot_all_modified_error_indices );
363+ cd.shot_all_modified_error_indices .clear ();
321364 }
322- cd.decoder ->update_internal_costs (cd.modified_error_indices );
323- cd.modified_error_indices .clear ();
324365 }
325366
326- // Clear shot-level tracking vectors.
327367 modified_component_indices.clear ();
328368 final_pass_active_components.clear ();
329369
0 commit comments