Skip to content

Commit 83ade9d

Browse files
committed
simplify wide path
1 parent 4bd1661 commit 83ade9d

1 file changed

Lines changed: 76 additions & 117 deletions

File tree

src/tesseract_trellis.cc

Lines changed: 76 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -57,29 +57,23 @@ struct Fault {
5757
std::vector<int> detectors;
5858
};
5959

60-
struct WideStateGroup {
61-
double mass;
62-
double score;
63-
size_t begin;
64-
size_t end;
65-
};
66-
6760
template <size_t Words>
6861
using FixedWideStateWords = std::array<uint64_t, Words>;
6962

7063
template <size_t Words>
71-
struct FixedWidePackedMass {
64+
struct FixedWideStateEntry {
7265
FixedWideStateWords<Words> state_words{};
73-
uint64_t obs_mask = 0;
74-
double mass = 0.0;
66+
double mass0 = 0.0;
67+
double mass1 = 0.0;
7568
double penalty = 0.0;
69+
double score = -INF;
7670
};
7771

7872
template <size_t Words>
7973
struct CompiledWideLayerTemplate {
8074
double q = 0.0;
8175
double p = 0.0;
82-
uint64_t obs_mask = 0;
76+
bool toggles_observable = false;
8377
std::array<uint64_t, Words> surviving_masks{};
8478
std::array<uint8_t, Words> projection_dst_words{};
8579
std::array<uint8_t, Words> projection_dst_offsets{};
@@ -314,6 +308,11 @@ double score_mass_and_penalty(double mass, double penalty,
314308
return std::log(mass) - penalty;
315309
}
316310

311+
template <size_t Words>
312+
TESSERACT_ALWAYS_INLINE double total_entry_mass(const FixedWideStateEntry<Words>& entry) {
313+
return entry.mass0 + entry.mass1;
314+
}
315+
317316
void reset_kept_state_stats(TesseractTrellisDecoder* decoder) {
318317
decoder->kept_state_sample_count = 0;
319318
decoder->kept_state_min = 0;
@@ -489,40 +488,35 @@ BranchPenaltyUpdate compute_compiled_wide_branch_update(
489488
}
490489

491490
template <size_t Words>
492-
void normalize_compiled_items(std::vector<FixedWidePackedMass<Words>>* items) {
491+
void normalize_compiled_items(std::vector<FixedWideStateEntry<Words>>* items) {
493492
double total = 0.0;
494493
for (const auto& item : *items) {
495-
total += item.mass;
494+
total += total_entry_mass(item);
496495
}
497496
if (total == 0.0) {
498497
return;
499498
}
500499
for (auto& item : *items) {
501-
item.mass /= total;
500+
item.mass0 /= total;
501+
item.mass1 /= total;
502502
}
503503
}
504504

505505
template <size_t Words>
506-
void merge_equal_compiled_keys_inplace(std::vector<FixedWidePackedMass<Words>>* items) {
506+
void merge_equal_compiled_keys_inplace(std::vector<FixedWideStateEntry<Words>>* items) {
507507
if (items->empty()) {
508508
return;
509509
}
510510
std::sort(items->begin(), items->end(),
511-
[](const FixedWidePackedMass<Words>& a, const FixedWidePackedMass<Words>& b) {
512-
if (fixed_wide_state_less(a.state_words, b.state_words)) {
513-
return true;
514-
}
515-
if (fixed_wide_state_less(b.state_words, a.state_words)) {
516-
return false;
517-
}
518-
return a.obs_mask < b.obs_mask;
511+
[](const FixedWideStateEntry<Words>& a, const FixedWideStateEntry<Words>& b) {
512+
return fixed_wide_state_less(a.state_words, b.state_words);
519513
});
520514

521515
size_t out = 0;
522516
for (size_t i = 1; i < items->size(); ++i) {
523-
if ((*items)[i].obs_mask == (*items)[out].obs_mask &&
524-
(*items)[i].state_words == (*items)[out].state_words) {
525-
(*items)[out].mass += (*items)[i].mass;
517+
if ((*items)[i].state_words == (*items)[out].state_words) {
518+
(*items)[out].mass0 += (*items)[i].mass0;
519+
(*items)[out].mass1 += (*items)[i].mass1;
526520
} else {
527521
++out;
528522
if (out != i) {
@@ -534,111 +528,61 @@ void merge_equal_compiled_keys_inplace(std::vector<FixedWidePackedMass<Words>>*
534528
}
535529

536530
template <size_t Words>
537-
bool compiled_wide_state_group_score_greater(const std::vector<FixedWidePackedMass<Words>>& entries,
538-
const WideStateGroup& a, const WideStateGroup& b) {
531+
bool compiled_state_score_greater(const FixedWideStateEntry<Words>& a,
532+
const FixedWideStateEntry<Words>& b) {
539533
if (a.score != b.score) {
540534
return a.score > b.score;
541535
}
542-
return fixed_wide_state_less(entries[a.begin].state_words, entries[b.begin].state_words);
536+
return fixed_wide_state_less(a.state_words, b.state_words);
543537
}
544538

545539
template <size_t Words>
546-
size_t trim_compiled_wide_state_groups_by_beam_and_mass(
547-
const std::vector<FixedWidePackedMass<Words>>& entries, std::vector<WideStateGroup>* groups,
548-
size_t beam_width, double beam_eps) {
549-
if (groups->empty()) {
540+
size_t keep_top_compiled_states(std::vector<FixedWideStateEntry<Words>>* entries,
541+
size_t beam_width, double beam_eps,
542+
TesseractTrellisRankingMode ranking_mode) {
543+
if (entries->empty()) {
550544
return 0;
551545
}
552546

553547
double total_mass = 0.0;
554-
if (beam_eps > 0.0) {
555-
for (const auto& group : *groups) {
556-
total_mass += group.mass;
548+
for (auto& entry : *entries) {
549+
const double mass = total_entry_mass(entry);
550+
entry.score = score_mass_and_penalty(mass, entry.penalty, ranking_mode);
551+
if (beam_eps > 0.0) {
552+
total_mass += mass;
557553
}
558554
}
559555

560-
if (groups->size() > beam_width) {
561-
std::nth_element(groups->begin(), groups->begin() + beam_width, groups->end(),
562-
[&entries](const WideStateGroup& a, const WideStateGroup& b) {
563-
return compiled_wide_state_group_score_greater(entries, a, b);
556+
if (entries->size() > beam_width) {
557+
std::nth_element(entries->begin(), entries->begin() + beam_width, entries->end(),
558+
[](const FixedWideStateEntry<Words>& a, const FixedWideStateEntry<Words>& b) {
559+
return compiled_state_score_greater(a, b);
564560
});
565-
groups->resize(beam_width);
561+
entries->resize(beam_width);
566562
} else if (beam_eps <= 0.0) {
567-
return groups->size();
563+
return entries->size();
568564
}
569565

570566
if (beam_eps <= 0.0 || total_mass <= 0.0) {
571-
return groups->size();
567+
return entries->size();
572568
}
573569

574-
std::sort(groups->begin(), groups->end(),
575-
[&entries](const WideStateGroup& a, const WideStateGroup& b) {
576-
return compiled_wide_state_group_score_greater(entries, a, b);
570+
std::sort(entries->begin(), entries->end(),
571+
[](const FixedWideStateEntry<Words>& a, const FixedWideStateEntry<Words>& b) {
572+
return compiled_state_score_greater(a, b);
577573
});
578574
const double retained_target_mass = total_mass * (1.0 - beam_eps);
579575
double retained_mass = 0.0;
580576
size_t keep_count = 0;
581-
while (keep_count < groups->size()) {
582-
retained_mass += (*groups)[keep_count].mass;
577+
while (keep_count < entries->size()) {
578+
retained_mass += total_entry_mass((*entries)[keep_count]);
583579
++keep_count;
584580
if (retained_mass >= retained_target_mass) {
585581
break;
586582
}
587583
}
588-
groups->resize(keep_count);
589-
std::sort(groups->begin(), groups->end(),
590-
[](const WideStateGroup& a, const WideStateGroup& b) { return a.begin < b.begin; });
591-
return groups->size();
592-
}
593-
594-
template <size_t Words>
595-
std::vector<WideStateGroup> collect_compiled_wide_state_groups(
596-
const std::vector<FixedWidePackedMass<Words>>& entries,
597-
TesseractTrellisRankingMode ranking_mode) {
598-
std::vector<WideStateGroup> groups;
599-
if (entries.empty()) {
600-
return groups;
601-
}
602-
groups.reserve(entries.size());
603-
size_t begin = 0;
604-
while (begin < entries.size()) {
605-
double mass = 0.0;
606-
size_t end = begin;
607-
while (end < entries.size() && entries[end].state_words == entries[begin].state_words) {
608-
mass += entries[end].mass;
609-
++end;
610-
}
611-
groups.push_back(
612-
{mass, score_mass_and_penalty(mass, entries[begin].penalty, ranking_mode), begin, end});
613-
begin = end;
614-
}
615-
return groups;
616-
}
617-
618-
template <size_t Words>
619-
size_t keep_top_compiled_states(std::vector<FixedWidePackedMass<Words>>* entries,
620-
size_t beam_width, double beam_eps,
621-
TesseractTrellisRankingMode ranking_mode) {
622-
if (entries->empty()) {
623-
return 0;
624-
}
625-
auto groups = collect_compiled_wide_state_groups(*entries, ranking_mode);
626-
const size_t kept_group_count =
627-
trim_compiled_wide_state_groups_by_beam_and_mass(*entries, &groups, beam_width, beam_eps);
628-
629-
std::vector<FixedWidePackedMass<Words>> kept;
630-
size_t kept_entries = 0;
631-
for (const auto& group : groups) {
632-
kept_entries += group.end - group.begin;
633-
}
634-
kept.reserve(kept_entries);
635-
for (const auto& group : groups) {
636-
for (size_t k = group.begin; k < group.end; ++k) {
637-
kept.push_back(std::move((*entries)[k]));
638-
}
639-
}
640-
*entries = std::move(kept);
641-
return kept_group_count;
584+
entries->resize(keep_count);
585+
return keep_count;
642586
}
643587

644588
template <size_t Words>
@@ -655,7 +599,10 @@ std::vector<CompiledWideLayerTemplate<Words>> compile_wide_layers(
655599
CompiledWideLayerTemplate<Words> compiled;
656600
compiled.q = layer.q;
657601
compiled.p = layer.p;
658-
compiled.obs_mask = layer.obs_mask;
602+
if (layer.obs_mask > 1) {
603+
throw std::invalid_argument("tesseract_trellis currently supports at most 1 observable");
604+
}
605+
compiled.toggles_observable = layer.obs_mask != 0;
659606

660607
std::array<uint64_t, Words> surviving_masks{};
661608
for (uint32_t current_local : layer.surviving_local_indices) {
@@ -738,11 +685,11 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
738685
actual_detector_words);
739686
}
740687

741-
std::vector<FixedWidePackedMass<Words>> beam_entries;
742-
std::vector<FixedWidePackedMass<Words>> next_entries;
688+
std::vector<FixedWideStateEntry<Words>> beam_entries;
689+
std::vector<FixedWideStateEntry<Words>> next_entries;
743690
beam_entries.reserve(decoder->config.beam_width * 2 + 2);
744691
next_entries.reserve(decoder->config.beam_width * 4 + 4);
745-
beam_entries.push_back({{}, 0, 1.0, initial_penalty});
692+
beam_entries.push_back({{}, 1.0, 0.0, initial_penalty});
746693
decoder->max_beam_size_seen = 1;
747694

748695
const bool compute_penalties =
@@ -776,16 +723,28 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
776723
FixedWideStateWords<Words> projected_toggled = projected_state;
777724
xor_compiled_wide_state(&projected_toggled, layer.projected_fault_mask_words);
778725
next_entries.push_back(
779-
{std::move(projected_state), item.obs_mask, item.mass * layer.q, update.absent_penalty});
780-
next_entries.push_back({std::move(projected_toggled), item.obs_mask ^ layer.obs_mask,
781-
item.mass * layer.p, update.present_penalty});
726+
{std::move(projected_state), item.mass0 * layer.q, item.mass1 * layer.q,
727+
update.absent_penalty});
728+
if (layer.toggles_observable) {
729+
next_entries.push_back({std::move(projected_toggled), item.mass1 * layer.p,
730+
item.mass0 * layer.p, update.present_penalty});
731+
} else {
732+
next_entries.push_back({std::move(projected_toggled), item.mass0 * layer.p,
733+
item.mass1 * layer.p, update.present_penalty});
734+
}
782735
} else if (keep_absent) {
783736
next_entries.push_back(
784-
{std::move(projected_state), item.obs_mask, item.mass * layer.q, update.absent_penalty});
737+
{std::move(projected_state), item.mass0 * layer.q, item.mass1 * layer.q,
738+
update.absent_penalty});
785739
} else if (keep_present) {
786740
xor_compiled_wide_state(&projected_state, layer.projected_fault_mask_words);
787-
next_entries.push_back({std::move(projected_state), item.obs_mask ^ layer.obs_mask,
788-
item.mass * layer.p, update.present_penalty});
741+
if (layer.toggles_observable) {
742+
next_entries.push_back({std::move(projected_state), item.mass1 * layer.p,
743+
item.mass0 * layer.p, update.present_penalty});
744+
} else {
745+
next_entries.push_back({std::move(projected_state), item.mass0 * layer.p,
746+
item.mass1 * layer.p, update.present_penalty});
747+
}
789748
}
790749
}
791750
auto t1 = std::chrono::high_resolution_clock::now();
@@ -820,11 +779,8 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
820779
if (!fixed_wide_state_zero(item.state_words)) {
821780
continue;
822781
}
823-
if (item.obs_mask == 0) {
824-
decoder->total_mass_obs0 += item.mass;
825-
} else if (item.obs_mask == 1) {
826-
decoder->total_mass_obs1 += item.mass;
827-
}
782+
decoder->total_mass_obs0 += item.mass0;
783+
decoder->total_mass_obs1 += item.mass1;
828784
}
829785
if (decoder->total_mass_obs0 == 0.0 && decoder->total_mass_obs1 == 0.0) {
830786
decoder->low_confidence_flag = true;
@@ -903,6 +859,9 @@ TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_)
903859
errors = get_errors_from_dem(config.dem.flattened());
904860
num_detectors = config.dem.count_detectors();
905861
num_observables = config.dem.count_observables();
862+
if (num_observables > 1) {
863+
throw std::invalid_argument("tesseract_trellis currently supports at most 1 observable");
864+
}
906865

907866
all_possible_detector_words.assign(num_state_words(num_detectors), 0);
908867
actual_detector_words_scratch.assign(all_possible_detector_words.size(), 0);

0 commit comments

Comments
 (0)