|
1 | 1 | #include "error_correlations.h" |
2 | | -#include <sstream> |
| 2 | + |
3 | 3 | #include <iostream> |
| 4 | +#include <sstream> |
4 | 5 |
|
5 | 6 | namespace tesseract { |
6 | 7 |
|
7 | 8 | std::string ImpliedProbability::str() const { |
8 | | - std::stringstream ss; |
9 | | - ss << "ImpliedProbability(affected={"; |
10 | | - for (size_t i = 0; i < affected_hyperedge.size(); ++i) { |
11 | | - ss << affected_hyperedge[i] << (i == affected_hyperedge.size() - 1 ? "" : ","); |
12 | | - } |
13 | | - ss << "}, prob=" << probability << ")"; |
14 | | - return ss.str(); |
| 9 | + std::stringstream ss; |
| 10 | + ss << "ImpliedProbability(affected={"; |
| 11 | + for (size_t i = 0; i < affected_hyperedge.size(); ++i) { |
| 12 | + ss << affected_hyperedge[i] << (i == affected_hyperedge.size() - 1 ? "" : ","); |
| 13 | + } |
| 14 | + ss << "}, prob=" << probability << ")"; |
| 15 | + return ss.str(); |
15 | 16 | } |
16 | 17 |
|
17 | 18 | bool ImpliedProbability::operator==(const ImpliedProbability& other) const { |
18 | | - return affected_hyperedge == other.affected_hyperedge && |
19 | | - std::abs(probability - other.probability) < 1e-12; |
| 19 | + return affected_hyperedge == other.affected_hyperedge && |
| 20 | + std::abs(probability - other.probability) < 1e-12; |
20 | 21 | } |
21 | 22 |
|
22 | 23 | bool ImpliedProbability::operator<(const ImpliedProbability& other) const { |
23 | | - if (affected_hyperedge != other.affected_hyperedge) { |
24 | | - return affected_hyperedge < other.affected_hyperedge; |
25 | | - } |
26 | | - return probability < other.probability; |
| 24 | + if (affected_hyperedge != other.affected_hyperedge) { |
| 25 | + return affected_hyperedge < other.affected_hyperedge; |
| 26 | + } |
| 27 | + return probability < other.probability; |
27 | 28 | } |
28 | 29 |
|
29 | | -JointProbsMap get_hyperedge_joint_probabilities(const stim::DetectorErrorModel& dem) { |
30 | | - JointProbsMap joint_probs; |
31 | | - auto flattened = dem.flattened(); |
32 | | - |
33 | | - for (const auto& inst : flattened.instructions) { |
34 | | - if (inst.type != stim::DemInstructionType::DEM_ERROR) continue; |
35 | | - |
36 | | - double p = inst.arg_data[0]; |
37 | | - |
38 | | - std::vector<Hyperedge> components; |
39 | | - size_t group_start = 0; |
40 | | - for (size_t k = 0; k <= inst.target_data.size(); ++k) { |
41 | | - if (k == inst.target_data.size() || inst.target_data[k].is_separator()) { |
42 | | - Hyperedge hyperedge; |
43 | | - for (size_t j = group_start; j < k; ++j) { |
44 | | - const auto& target = inst.target_data[j]; |
45 | | - if (target.is_relative_detector_id()) { |
46 | | - hyperedge.push_back(target.val()); |
47 | | - } |
48 | | - } |
49 | | - if (!hyperedge.empty()) { |
50 | | - std::sort(hyperedge.begin(), hyperedge.end()); |
51 | | - components.push_back(hyperedge); |
52 | | - } |
53 | | - group_start = k + 1; |
54 | | - } |
55 | | - } |
| 30 | +JointProbsMap get_hyperedge_joint_probabilities(const stim::DetectorErrorModel& dem, |
| 31 | + const std::vector<int>& global_det_to_comp_id) { |
| 32 | + JointProbsMap joint_probs; |
| 33 | + auto flattened = dem.flattened(); |
56 | 34 |
|
57 | | - // 1. Marginal probabilities (diagonal) |
58 | | - for (const auto& h : components) { |
59 | | - if (joint_probs[h].find(h) == joint_probs[h].end()) { |
60 | | - joint_probs[h][h] = 0.0; |
61 | | - } |
62 | | - // P(A) = P(A) XOR p |
63 | | - joint_probs[h][h] = joint_probs[h][h] * (1 - p) + p * (1 - joint_probs[h][h]); |
64 | | - } |
| 35 | + for (const auto& inst : flattened.instructions) { |
| 36 | + if (inst.type != stim::DemInstructionType::DEM_ERROR) continue; |
| 37 | + |
| 38 | + double p = inst.arg_data[0]; |
| 39 | + |
| 40 | + std::map<int, Hyperedge> comp_targets; |
| 41 | + for (const auto& target : inst.target_data) { |
| 42 | + if (target.is_relative_detector_id()) { |
| 43 | + int d = target.val(); |
| 44 | + int cid = |
| 45 | + (d >= 0 && (size_t)d < global_det_to_comp_id.size()) ? global_det_to_comp_id[d] : -1; |
| 46 | + if (cid != -1) comp_targets[cid].push_back(d); |
| 47 | + } |
| 48 | + } |
| 49 | + |
| 50 | + std::vector<Hyperedge> components; |
| 51 | + for (auto& [cid, h] : comp_targets) { |
| 52 | + std::sort(h.begin(), h.end()); |
| 53 | + components.push_back(h); |
| 54 | + } |
| 55 | + |
| 56 | + // 1. Marginal probabilities (diagonal) |
| 57 | + for (const auto& h : components) { |
| 58 | + if (joint_probs[h].find(h) == joint_probs[h].end()) { |
| 59 | + joint_probs[h][h] = 0.0; |
| 60 | + } |
| 61 | + // P(A) = P(A) XOR p |
| 62 | + joint_probs[h][h] = joint_probs[h][h] * (1 - p) + p * (1 - joint_probs[h][h]); |
| 63 | + } |
65 | 64 |
|
66 | | - // 2. Joint probabilities (off-diagonal) |
67 | | - // For a bridging error p connecting A and B, P(A and B) += p (approx) |
68 | | - // Actually, the joint probability is accurately tracked via the same XOR logic |
69 | | - // if we assume independence of other error mechanisms. |
70 | | - if (components.size() > 1) { |
71 | | - for (size_t i = 0; i < components.size(); ++i) { |
72 | | - for (size_t j = 0; j < components.size(); ++j) { |
73 | | - if (i == j) continue; |
74 | | - const auto& hi = components[i]; |
75 | | - const auto& hj = components[j]; |
76 | | - if (joint_probs[hi].find(hj) == joint_probs[hi].end()) { |
77 | | - joint_probs[hi][hj] = 0.0; |
78 | | - } |
79 | | - // For small p, joint probability P(A and B) is roughly the sum of p's of bridging errors |
80 | | - joint_probs[hi][hj] = joint_probs[hi][hj] * (1 - p) + p * (1 - joint_probs[hi][hj]); |
81 | | - } |
82 | | - } |
| 65 | + // 2. Joint probabilities (off-diagonal) |
| 66 | + // For a bridging error p connecting A and B, P(A and B) += p (approx) |
| 67 | + // Actually, the joint probability is accurately tracked via the same XOR logic |
| 68 | + // if we assume independence of other error mechanisms. |
| 69 | + if (components.size() > 1) { |
| 70 | + for (size_t i = 0; i < components.size(); ++i) { |
| 71 | + for (size_t j = 0; j < components.size(); ++j) { |
| 72 | + if (i == j) continue; |
| 73 | + const auto& hi = components[i]; |
| 74 | + const auto& hj = components[j]; |
| 75 | + if (joint_probs[hi].find(hj) == joint_probs[hi].end()) { |
| 76 | + joint_probs[hi][hj] = 0.0; |
| 77 | + } |
| 78 | + // For small p, joint probability P(A and B) is roughly the sum of p's of bridging errors |
| 79 | + joint_probs[hi][hj] = joint_probs[hi][hj] * (1 - p) + p * (1 - joint_probs[hi][hj]); |
83 | 80 | } |
| 81 | + } |
84 | 82 | } |
| 83 | + } |
85 | 84 |
|
86 | | - return joint_probs; |
| 85 | + return joint_probs; |
87 | 86 | } |
88 | 87 |
|
89 | 88 | ImpliedProbsMap get_implied_hyperedge_probabilities(const JointProbsMap& joint_probs) { |
90 | | - ImpliedProbsMap implied_probs; |
| 89 | + ImpliedProbsMap implied_probs; |
91 | 90 |
|
92 | | - for (const auto& [causal, affected_map] : joint_probs) { |
93 | | - double p_causal = 0.0; |
94 | | - auto it_self = affected_map.find(causal); |
95 | | - if (it_self != affected_map.end()) { |
96 | | - p_causal = it_self->second; |
97 | | - } |
| 91 | + for (const auto& [causal, affected_map] : joint_probs) { |
| 92 | + double p_causal = 0.0; |
| 93 | + auto it_self = affected_map.find(causal); |
| 94 | + if (it_self != affected_map.end()) { |
| 95 | + p_causal = it_self->second; |
| 96 | + } |
98 | 97 |
|
99 | | - if (p_causal <= 0 || p_causal >= 1.0) continue; |
| 98 | + if (p_causal <= 0 || p_causal >= 1.0) continue; |
100 | 99 |
|
101 | | - for (const auto& [affected, p_joint] : affected_map) { |
102 | | - if (causal == affected) continue; |
| 100 | + for (const auto& [affected, p_joint] : affected_map) { |
| 101 | + if (causal == affected) continue; |
103 | 102 |
|
104 | | - // Conditional Probability P(affected | causal) = P(affected and causal) / P(causal) |
105 | | - double p_conditional = p_joint / p_causal; |
106 | | - |
107 | | - // Cap to 1.0 (numerical precision) |
108 | | - if (p_conditional > 1.0) p_conditional = 1.0; |
109 | | - if (p_conditional < 0.0) p_conditional = 0.0; |
| 103 | + // Conditional Probability P(affected | causal) = P(affected and causal) / P(causal) |
| 104 | + double p_conditional = p_joint / p_causal; |
110 | 105 |
|
111 | | - implied_probs[causal].push_back({affected, p_conditional}); |
112 | | - } |
| 106 | + // Cap to 1.0 (numerical precision) |
| 107 | + if (p_conditional > 1.0) p_conditional = 1.0; |
| 108 | + if (p_conditional < 0.0) p_conditional = 0.0; |
| 109 | + |
| 110 | + implied_probs[causal].push_back({affected, p_conditional}); |
113 | 111 | } |
| 112 | + } |
114 | 113 |
|
115 | | - return implied_probs; |
| 114 | + return implied_probs; |
116 | 115 | } |
117 | 116 |
|
118 | | -ImpliedProbsMap process_dem_correlations(const stim::DetectorErrorModel& dem) { |
119 | | - auto joint = get_hyperedge_joint_probabilities(dem); |
120 | | - return get_implied_hyperedge_probabilities(joint); |
| 117 | +ImpliedProbsMap process_dem_correlations(const stim::DetectorErrorModel& dem, |
| 118 | + const std::vector<int>& global_det_to_comp_id) { |
| 119 | + auto joint = get_hyperedge_joint_probabilities(dem, global_det_to_comp_id); |
| 120 | + return get_implied_hyperedge_probabilities(joint); |
121 | 121 | } |
122 | 122 |
|
123 | | -} // namespace tesseract |
| 123 | +} // namespace tesseract |
0 commit comments