Skip to content

Commit e165045

Browse files
committed
fix: resolve C++ detector-to-component correlation analyzer bug
TAG=agy
1 parent 1217aa7 commit e165045

4 files changed

Lines changed: 134 additions & 104 deletions

File tree

src/error_correlations.cc

Lines changed: 90 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,123 @@
11
#include "error_correlations.h"
2-
#include <sstream>
2+
33
#include <iostream>
4+
#include <sstream>
45

56
namespace tesseract {
67

78
std::string ImpliedProbability::str() const {
8-
std::stringstream ss;
9-
ss << "ImpliedProbability(affected={";
10-
for (size_t i = 0; i < affected_hyperedge.size(); ++i) {
11-
ss << affected_hyperedge[i] << (i == affected_hyperedge.size() - 1 ? "" : ",");
12-
}
13-
ss << "}, prob=" << probability << ")";
14-
return ss.str();
9+
std::stringstream ss;
10+
ss << "ImpliedProbability(affected={";
11+
for (size_t i = 0; i < affected_hyperedge.size(); ++i) {
12+
ss << affected_hyperedge[i] << (i == affected_hyperedge.size() - 1 ? "" : ",");
13+
}
14+
ss << "}, prob=" << probability << ")";
15+
return ss.str();
1516
}
1617

1718
bool ImpliedProbability::operator==(const ImpliedProbability& other) const {
18-
return affected_hyperedge == other.affected_hyperedge &&
19-
std::abs(probability - other.probability) < 1e-12;
19+
return affected_hyperedge == other.affected_hyperedge &&
20+
std::abs(probability - other.probability) < 1e-12;
2021
}
2122

2223
bool ImpliedProbability::operator<(const ImpliedProbability& other) const {
23-
if (affected_hyperedge != other.affected_hyperedge) {
24-
return affected_hyperedge < other.affected_hyperedge;
25-
}
26-
return probability < other.probability;
24+
if (affected_hyperedge != other.affected_hyperedge) {
25+
return affected_hyperedge < other.affected_hyperedge;
26+
}
27+
return probability < other.probability;
2728
}
2829

29-
JointProbsMap get_hyperedge_joint_probabilities(const stim::DetectorErrorModel& dem) {
30-
JointProbsMap joint_probs;
31-
auto flattened = dem.flattened();
32-
33-
for (const auto& inst : flattened.instructions) {
34-
if (inst.type != stim::DemInstructionType::DEM_ERROR) continue;
35-
36-
double p = inst.arg_data[0];
37-
38-
std::vector<Hyperedge> components;
39-
size_t group_start = 0;
40-
for (size_t k = 0; k <= inst.target_data.size(); ++k) {
41-
if (k == inst.target_data.size() || inst.target_data[k].is_separator()) {
42-
Hyperedge hyperedge;
43-
for (size_t j = group_start; j < k; ++j) {
44-
const auto& target = inst.target_data[j];
45-
if (target.is_relative_detector_id()) {
46-
hyperedge.push_back(target.val());
47-
}
48-
}
49-
if (!hyperedge.empty()) {
50-
std::sort(hyperedge.begin(), hyperedge.end());
51-
components.push_back(hyperedge);
52-
}
53-
group_start = k + 1;
54-
}
55-
}
30+
JointProbsMap get_hyperedge_joint_probabilities(const stim::DetectorErrorModel& dem,
31+
const std::vector<int>& global_det_to_comp_id) {
32+
JointProbsMap joint_probs;
33+
auto flattened = dem.flattened();
5634

57-
// 1. Marginal probabilities (diagonal)
58-
for (const auto& h : components) {
59-
if (joint_probs[h].find(h) == joint_probs[h].end()) {
60-
joint_probs[h][h] = 0.0;
61-
}
62-
// P(A) = P(A) XOR p
63-
joint_probs[h][h] = joint_probs[h][h] * (1 - p) + p * (1 - joint_probs[h][h]);
64-
}
35+
for (const auto& inst : flattened.instructions) {
36+
if (inst.type != stim::DemInstructionType::DEM_ERROR) continue;
37+
38+
double p = inst.arg_data[0];
39+
40+
std::map<int, Hyperedge> comp_targets;
41+
for (const auto& target : inst.target_data) {
42+
if (target.is_relative_detector_id()) {
43+
int d = target.val();
44+
int cid =
45+
(d >= 0 && (size_t)d < global_det_to_comp_id.size()) ? global_det_to_comp_id[d] : -1;
46+
if (cid != -1) comp_targets[cid].push_back(d);
47+
}
48+
}
49+
50+
std::vector<Hyperedge> components;
51+
for (auto& [cid, h] : comp_targets) {
52+
std::sort(h.begin(), h.end());
53+
components.push_back(h);
54+
}
55+
56+
// 1. Marginal probabilities (diagonal)
57+
for (const auto& h : components) {
58+
if (joint_probs[h].find(h) == joint_probs[h].end()) {
59+
joint_probs[h][h] = 0.0;
60+
}
61+
// P(A) = P(A) XOR p
62+
joint_probs[h][h] = joint_probs[h][h] * (1 - p) + p * (1 - joint_probs[h][h]);
63+
}
6564

66-
// 2. Joint probabilities (off-diagonal)
67-
// For a bridging error p connecting A and B, P(A and B) += p (approx)
68-
// Actually, the joint probability is accurately tracked via the same XOR logic
69-
// if we assume independence of other error mechanisms.
70-
if (components.size() > 1) {
71-
for (size_t i = 0; i < components.size(); ++i) {
72-
for (size_t j = 0; j < components.size(); ++j) {
73-
if (i == j) continue;
74-
const auto& hi = components[i];
75-
const auto& hj = components[j];
76-
if (joint_probs[hi].find(hj) == joint_probs[hi].end()) {
77-
joint_probs[hi][hj] = 0.0;
78-
}
79-
// For small p, joint probability P(A and B) is roughly the sum of p's of bridging errors
80-
joint_probs[hi][hj] = joint_probs[hi][hj] * (1 - p) + p * (1 - joint_probs[hi][hj]);
81-
}
82-
}
65+
// 2. Joint probabilities (off-diagonal)
66+
// For a bridging error p connecting A and B, P(A and B) += p (approx)
67+
// Actually, the joint probability is accurately tracked via the same XOR logic
68+
// if we assume independence of other error mechanisms.
69+
if (components.size() > 1) {
70+
for (size_t i = 0; i < components.size(); ++i) {
71+
for (size_t j = 0; j < components.size(); ++j) {
72+
if (i == j) continue;
73+
const auto& hi = components[i];
74+
const auto& hj = components[j];
75+
if (joint_probs[hi].find(hj) == joint_probs[hi].end()) {
76+
joint_probs[hi][hj] = 0.0;
77+
}
78+
// For small p, joint probability P(A and B) is roughly the sum of p's of bridging errors
79+
joint_probs[hi][hj] = joint_probs[hi][hj] * (1 - p) + p * (1 - joint_probs[hi][hj]);
8380
}
81+
}
8482
}
83+
}
8584

86-
return joint_probs;
85+
return joint_probs;
8786
}
8887

8988
ImpliedProbsMap get_implied_hyperedge_probabilities(const JointProbsMap& joint_probs) {
90-
ImpliedProbsMap implied_probs;
89+
ImpliedProbsMap implied_probs;
9190

92-
for (const auto& [causal, affected_map] : joint_probs) {
93-
double p_causal = 0.0;
94-
auto it_self = affected_map.find(causal);
95-
if (it_self != affected_map.end()) {
96-
p_causal = it_self->second;
97-
}
91+
for (const auto& [causal, affected_map] : joint_probs) {
92+
double p_causal = 0.0;
93+
auto it_self = affected_map.find(causal);
94+
if (it_self != affected_map.end()) {
95+
p_causal = it_self->second;
96+
}
9897

99-
if (p_causal <= 0 || p_causal >= 1.0) continue;
98+
if (p_causal <= 0 || p_causal >= 1.0) continue;
10099

101-
for (const auto& [affected, p_joint] : affected_map) {
102-
if (causal == affected) continue;
100+
for (const auto& [affected, p_joint] : affected_map) {
101+
if (causal == affected) continue;
103102

104-
// Conditional Probability P(affected | causal) = P(affected and causal) / P(causal)
105-
double p_conditional = p_joint / p_causal;
106-
107-
// Cap to 1.0 (numerical precision)
108-
if (p_conditional > 1.0) p_conditional = 1.0;
109-
if (p_conditional < 0.0) p_conditional = 0.0;
103+
// Conditional Probability P(affected | causal) = P(affected and causal) / P(causal)
104+
double p_conditional = p_joint / p_causal;
110105

111-
implied_probs[causal].push_back({affected, p_conditional});
112-
}
106+
// Cap to 1.0 (numerical precision)
107+
if (p_conditional > 1.0) p_conditional = 1.0;
108+
if (p_conditional < 0.0) p_conditional = 0.0;
109+
110+
implied_probs[causal].push_back({affected, p_conditional});
113111
}
112+
}
114113

115-
return implied_probs;
114+
return implied_probs;
116115
}
117116

118-
ImpliedProbsMap process_dem_correlations(const stim::DetectorErrorModel& dem) {
119-
auto joint = get_hyperedge_joint_probabilities(dem);
120-
return get_implied_hyperedge_probabilities(joint);
117+
ImpliedProbsMap process_dem_correlations(const stim::DetectorErrorModel& dem,
118+
const std::vector<int>& global_det_to_comp_id) {
119+
auto joint = get_hyperedge_joint_probabilities(dem, global_det_to_comp_id);
120+
return get_implied_hyperedge_probabilities(joint);
121121
}
122122

123-
} // namespace tesseract
123+
} // namespace tesseract

src/error_correlations.h

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
#ifndef ERROR_CORRELATIONS_H
22
#define ERROR_CORRELATIONS_H
33

4+
#include <algorithm>
45
#include <cmath>
56
#include <map>
67
#include <numeric>
7-
#include <vector>
88
#include <string>
9-
#include <algorithm>
9+
#include <vector>
1010

1111
#include "stim.h"
1212

@@ -16,27 +16,29 @@ namespace tesseract {
1616
* Represents a probability adjustment for an affected hyperedge given a causal hyperedge.
1717
*/
1818
struct ImpliedProbability {
19-
std::vector<int> affected_hyperedge;
20-
double probability; // Represents the conditional probability P(affected | causal)
19+
std::vector<int> affected_hyperedge;
20+
double probability; // Represents the conditional probability P(affected | causal)
2121

22-
std::string str() const;
23-
bool operator==(const ImpliedProbability& other) const;
24-
bool operator<(const ImpliedProbability& other) const;
22+
std::string str() const;
23+
bool operator==(const ImpliedProbability& other) const;
24+
bool operator<(const ImpliedProbability& other) const;
2525
};
2626

2727
// Type alias for hyperedge (sorted detector indices)
2828
using Hyperedge = std::vector<int>;
2929
// Type alias for joint probabilities map: causal_hyperedge -> {affected_hyperedge -> joint_prob}
3030
using JointProbsMap = std::map<Hyperedge, std::map<Hyperedge, double>>;
31-
// Type alias for implied probabilities map: causal_hyperedge -> list of conditional probability updates
31+
// Type alias for implied probabilities map: causal_hyperedge -> list of conditional probability
32+
// updates
3233
using ImpliedProbsMap = std::map<Hyperedge, std::vector<ImpliedProbability>>;
3334

3435
/**
3536
* Calculates marginal and joint probabilities for hyperedges in a DEM.
36-
* Note: Assumes the input DEM has NOT been decomposed yet, as we need bridging errors
37+
* Note: Assumes the input DEM has NOT been decomposed yet, as we need bridging errors
3738
* to find joint probabilities.
3839
*/
39-
JointProbsMap get_hyperedge_joint_probabilities(const stim::DetectorErrorModel& dem);
40+
JointProbsMap get_hyperedge_joint_probabilities(const stim::DetectorErrorModel& dem,
41+
const std::vector<int>& global_det_to_comp_id);
4042

4143
/**
4244
* Calculates conditional probabilities from joint probabilities.
@@ -46,8 +48,9 @@ ImpliedProbsMap get_implied_hyperedge_probabilities(const JointProbsMap& joint_p
4648
/**
4749
* Complete workflow for analyzing correlations within a stim::DetectorErrorModel.
4850
*/
49-
ImpliedProbsMap process_dem_correlations(const stim::DetectorErrorModel& dem);
51+
ImpliedProbsMap process_dem_correlations(const stim::DetectorErrorModel& dem,
52+
const std::vector<int>& global_det_to_comp_id);
5053

51-
} // namespace tesseract
54+
} // namespace tesseract
5255

53-
#endif // ERROR_CORRELATIONS_H
56+
#endif // ERROR_CORRELATIONS_H

src/multi_pass_tesseract_decoder.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ void MultiPassTesseractDecoder::initialize(const stim::DetectorErrorModel& dem,
5353
// std::cout << "DEBUG decomposed:\n" << decomposed << std::endl;
5454
stim::DetectorErrorModel merged = merge_indistinguishable_errors(decomposed);
5555
// std::cout << "DEBUG merged:\n" << merged << std::endl;
56-
ImpliedProbsMap raw_correlations = process_dem_correlations(merged);
5756

5857
std::set<int> unique_classes;
5958
for (int c : detector_classes)
@@ -82,6 +81,8 @@ void MultiPassTesseractDecoder::initialize(const stim::DetectorErrorModel& dem,
8281
}
8382
}
8483

84+
ImpliedProbsMap raw_correlations = process_dem_correlations(flattened, global_det_to_comp_id);
85+
8586
auto component_dems_raw = split_dem_by_component(merged, [&](int d) {
8687
return (d >= 0 && (size_t)d < total_global_detectors) ? global_det_to_comp_id[d] : -1;
8788
});

src/multi_pass_tesseract_decoder.test.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,29 @@ TEST(MultiPassTesseractDecoderTest, BoundaryConditionAndCappingTest) {
269269
std::vector<uint64_t> hits = {0};
270270
ASSERT_NO_THROW(decoder.decode(hits));
271271
}
272+
273+
TEST(MultiPassTesseractDecoderTest, IntermediatePassLeakageTest) {
274+
stim::DetectorErrorModel dem(R"DEM(
275+
error(0.1) D0 D1 L0
276+
error(0.01) D0
277+
error(0.2) D1 L0
278+
detector D0
279+
detector D1
280+
logical_observable L0
281+
)DEM");
282+
283+
auto classifier = [](int index, const std::vector<double>& coords, const std::string& tag) -> int {
284+
return index;
285+
};
286+
287+
TesseractConfig config;
288+
config.dem = dem;
289+
290+
MultiPassTesseractDecoder decoder(dem, 3, classifier, config, 1, DetOrder::DetIndex, 12345, SchedulingStrategy::Causal);
291+
292+
std::vector<uint64_t> hits = {0};
293+
decoder.decode(hits);
294+
295+
// Rigorously assert that prior LLR reweights occurred successfully on the raw un-decomposed DEM!
296+
ASSERT_GT(decoder.get_last_shot_num_reweights(), 0);
297+
}

0 commit comments

Comments
 (0)