@@ -57,29 +57,23 @@ struct Fault {
5757 std::vector<int > detectors;
5858};
5959
60- struct WideStateGroup {
61- double mass;
62- double score;
63- size_t begin;
64- size_t end;
65- };
66-
// Fixed-size packed state key: `Words` 64-bit words held in a std::array so
// that keys can be compared and copied by value.
template <size_t Words>
using FixedWideStateWords = std::array<uint64_t, Words>;
6962
7063template <size_t Words>
71- struct FixedWidePackedMass {
64+ struct FixedWideStateEntry {
7265 FixedWideStateWords<Words> state_words{};
73- uint64_t obs_mask = 0 ;
74- double mass = 0.0 ;
66+ double mass0 = 0. 0 ;
67+ double mass1 = 0.0 ;
7568 double penalty = 0.0 ;
69+ double score = -INF;
7670};
7771
7872template <size_t Words>
7973struct CompiledWideLayerTemplate {
8074 double q = 0.0 ;
8175 double p = 0.0 ;
82- uint64_t obs_mask = 0 ;
76+ bool toggles_observable = false ;
8377 std::array<uint64_t , Words> surviving_masks{};
8478 std::array<uint8_t , Words> projection_dst_words{};
8579 std::array<uint8_t , Words> projection_dst_offsets{};
@@ -314,6 +308,11 @@ double score_mass_and_penalty(double mass, double penalty,
314308 return std::log (mass) - penalty;
315309}
316310
311+ template <size_t Words>
312+ TESSERACT_ALWAYS_INLINE double total_entry_mass (const FixedWideStateEntry<Words>& entry) {
313+ return entry.mass0 + entry.mass1 ;
314+ }
315+
317316void reset_kept_state_stats (TesseractTrellisDecoder* decoder) {
318317 decoder->kept_state_sample_count = 0 ;
319318 decoder->kept_state_min = 0 ;
@@ -489,40 +488,35 @@ BranchPenaltyUpdate compute_compiled_wide_branch_update(
489488}
490489
491490template <size_t Words>
492- void normalize_compiled_items (std::vector<FixedWidePackedMass <Words>>* items) {
491+ void normalize_compiled_items (std::vector<FixedWideStateEntry <Words>>* items) {
493492 double total = 0.0 ;
494493 for (const auto & item : *items) {
495- total += item. mass ;
494+ total += total_entry_mass ( item) ;
496495 }
497496 if (total == 0.0 ) {
498497 return ;
499498 }
500499 for (auto & item : *items) {
501- item.mass /= total;
500+ item.mass0 /= total;
501+ item.mass1 /= total;
502502 }
503503}
504504
505505template <size_t Words>
506- void merge_equal_compiled_keys_inplace (std::vector<FixedWidePackedMass <Words>>* items) {
506+ void merge_equal_compiled_keys_inplace (std::vector<FixedWideStateEntry <Words>>* items) {
507507 if (items->empty ()) {
508508 return ;
509509 }
510510 std::sort (items->begin (), items->end (),
511- [](const FixedWidePackedMass<Words>& a, const FixedWidePackedMass<Words>& b) {
512- if (fixed_wide_state_less (a.state_words , b.state_words )) {
513- return true ;
514- }
515- if (fixed_wide_state_less (b.state_words , a.state_words )) {
516- return false ;
517- }
518- return a.obs_mask < b.obs_mask ;
511+ [](const FixedWideStateEntry<Words>& a, const FixedWideStateEntry<Words>& b) {
512+ return fixed_wide_state_less (a.state_words , b.state_words );
519513 });
520514
521515 size_t out = 0 ;
522516 for (size_t i = 1 ; i < items->size (); ++i) {
523- if ((*items)[i].obs_mask == (*items)[out].obs_mask &&
524- (*items)[i]. state_words == (*items)[out]. state_words ) {
525- (*items)[out].mass += (*items)[i].mass ;
517+ if ((*items)[i].state_words == (*items)[out].state_words ) {
518+ (*items)[out]. mass0 += (*items)[i]. mass0 ;
519+ (*items)[out].mass1 += (*items)[i].mass1 ;
526520 } else {
527521 ++out;
528522 if (out != i) {
@@ -534,111 +528,61 @@ void merge_equal_compiled_keys_inplace(std::vector<FixedWidePackedMass<Words>>*
534528}
535529
536530template <size_t Words>
537- bool compiled_wide_state_group_score_greater (const std::vector<FixedWidePackedMass< Words>>& entries ,
538- const WideStateGroup& a, const WideStateGroup & b) {
531+ bool compiled_state_score_greater (const FixedWideStateEntry< Words>& a ,
532+ const FixedWideStateEntry<Words> & b) {
539533 if (a.score != b.score ) {
540534 return a.score > b.score ;
541535 }
542- return fixed_wide_state_less (entries[a. begin ]. state_words , entries[b. begin ] .state_words );
536+ return fixed_wide_state_less (a. state_words , b .state_words );
543537}
544538
545539template <size_t Words>
546- size_t trim_compiled_wide_state_groups_by_beam_and_mass (
547- const std::vector<FixedWidePackedMass<Words>>& entries, std::vector<WideStateGroup>* groups ,
548- size_t beam_width, double beam_eps ) {
549- if (groups ->empty ()) {
540+ size_t keep_top_compiled_states (std::vector<FixedWideStateEntry<Words>>* entries,
541+ size_t beam_width, double beam_eps ,
542+ TesseractTrellisRankingMode ranking_mode ) {
543+ if (entries ->empty ()) {
550544 return 0 ;
551545 }
552546
553547 double total_mass = 0.0 ;
554- if (beam_eps > 0.0 ) {
555- for (const auto & group : *groups) {
556- total_mass += group.mass ;
548+ for (auto & entry : *entries) {
549+ const double mass = total_entry_mass (entry);
550+ entry.score = score_mass_and_penalty (mass, entry.penalty , ranking_mode);
551+ if (beam_eps > 0.0 ) {
552+ total_mass += mass;
557553 }
558554 }
559555
560- if (groups ->size () > beam_width) {
561- std::nth_element (groups ->begin (), groups ->begin () + beam_width, groups ->end (),
562- [&entries ](const WideStateGroup & a, const WideStateGroup & b) {
563- return compiled_wide_state_group_score_greater (entries, a, b);
556+ if (entries ->size () > beam_width) {
557+ std::nth_element (entries ->begin (), entries ->begin () + beam_width, entries ->end (),
558+ [](const FixedWideStateEntry<Words> & a, const FixedWideStateEntry<Words> & b) {
559+ return compiled_state_score_greater ( a, b);
564560 });
565- groups ->resize (beam_width);
561+ entries ->resize (beam_width);
566562 } else if (beam_eps <= 0.0 ) {
567- return groups ->size ();
563+ return entries ->size ();
568564 }
569565
570566 if (beam_eps <= 0.0 || total_mass <= 0.0 ) {
571- return groups ->size ();
567+ return entries ->size ();
572568 }
573569
574- std::sort (groups ->begin (), groups ->end (),
575- [&entries ](const WideStateGroup & a, const WideStateGroup & b) {
576- return compiled_wide_state_group_score_greater (entries, a, b);
570+ std::sort (entries ->begin (), entries ->end (),
571+ [](const FixedWideStateEntry<Words> & a, const FixedWideStateEntry<Words> & b) {
572+ return compiled_state_score_greater ( a, b);
577573 });
578574 const double retained_target_mass = total_mass * (1.0 - beam_eps);
579575 double retained_mass = 0.0 ;
580576 size_t keep_count = 0 ;
581- while (keep_count < groups ->size ()) {
582- retained_mass += (*groups )[keep_count]. mass ;
577+ while (keep_count < entries ->size ()) {
578+ retained_mass += total_entry_mass ((*entries )[keep_count]) ;
583579 ++keep_count;
584580 if (retained_mass >= retained_target_mass) {
585581 break ;
586582 }
587583 }
588- groups->resize (keep_count);
589- std::sort (groups->begin (), groups->end (),
590- [](const WideStateGroup& a, const WideStateGroup& b) { return a.begin < b.begin ; });
591- return groups->size ();
592- }
593-
594- template <size_t Words>
595- std::vector<WideStateGroup> collect_compiled_wide_state_groups (
596- const std::vector<FixedWidePackedMass<Words>>& entries,
597- TesseractTrellisRankingMode ranking_mode) {
598- std::vector<WideStateGroup> groups;
599- if (entries.empty ()) {
600- return groups;
601- }
602- groups.reserve (entries.size ());
603- size_t begin = 0 ;
604- while (begin < entries.size ()) {
605- double mass = 0.0 ;
606- size_t end = begin;
607- while (end < entries.size () && entries[end].state_words == entries[begin].state_words ) {
608- mass += entries[end].mass ;
609- ++end;
610- }
611- groups.push_back (
612- {mass, score_mass_and_penalty (mass, entries[begin].penalty , ranking_mode), begin, end});
613- begin = end;
614- }
615- return groups;
616- }
617-
618- template <size_t Words>
619- size_t keep_top_compiled_states (std::vector<FixedWidePackedMass<Words>>* entries,
620- size_t beam_width, double beam_eps,
621- TesseractTrellisRankingMode ranking_mode) {
622- if (entries->empty ()) {
623- return 0 ;
624- }
625- auto groups = collect_compiled_wide_state_groups (*entries, ranking_mode);
626- const size_t kept_group_count =
627- trim_compiled_wide_state_groups_by_beam_and_mass (*entries, &groups, beam_width, beam_eps);
628-
629- std::vector<FixedWidePackedMass<Words>> kept;
630- size_t kept_entries = 0 ;
631- for (const auto & group : groups) {
632- kept_entries += group.end - group.begin ;
633- }
634- kept.reserve (kept_entries);
635- for (const auto & group : groups) {
636- for (size_t k = group.begin ; k < group.end ; ++k) {
637- kept.push_back (std::move ((*entries)[k]));
638- }
639- }
640- *entries = std::move (kept);
641- return kept_group_count;
584+ entries->resize (keep_count);
585+ return keep_count;
642586}
643587
644588template <size_t Words>
@@ -655,7 +599,10 @@ std::vector<CompiledWideLayerTemplate<Words>> compile_wide_layers(
655599 CompiledWideLayerTemplate<Words> compiled;
656600 compiled.q = layer.q ;
657601 compiled.p = layer.p ;
658- compiled.obs_mask = layer.obs_mask ;
602+ if (layer.obs_mask > 1 ) {
603+ throw std::invalid_argument (" tesseract_trellis currently supports at most 1 observable" );
604+ }
605+ compiled.toggles_observable = layer.obs_mask != 0 ;
659606
660607 std::array<uint64_t , Words> surviving_masks{};
661608 for (uint32_t current_local : layer.surviving_local_indices ) {
@@ -738,11 +685,11 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
738685 actual_detector_words);
739686 }
740687
741- std::vector<FixedWidePackedMass <Words>> beam_entries;
742- std::vector<FixedWidePackedMass <Words>> next_entries;
688+ std::vector<FixedWideStateEntry <Words>> beam_entries;
689+ std::vector<FixedWideStateEntry <Words>> next_entries;
743690 beam_entries.reserve (decoder->config .beam_width * 2 + 2 );
744691 next_entries.reserve (decoder->config .beam_width * 4 + 4 );
745- beam_entries.push_back ({{}, 0 , 1 .0 , initial_penalty});
692+ beam_entries.push_back ({{}, 1. 0 , 0 .0 , initial_penalty});
746693 decoder->max_beam_size_seen = 1 ;
747694
748695 const bool compute_penalties =
@@ -776,16 +723,28 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
776723 FixedWideStateWords<Words> projected_toggled = projected_state;
777724 xor_compiled_wide_state (&projected_toggled, layer.projected_fault_mask_words );
778725 next_entries.push_back (
779- {std::move (projected_state), item.obs_mask , item.mass * layer.q , update.absent_penalty });
780- next_entries.push_back ({std::move (projected_toggled), item.obs_mask ^ layer.obs_mask ,
781- item.mass * layer.p , update.present_penalty });
726+ {std::move (projected_state), item.mass0 * layer.q , item.mass1 * layer.q ,
727+ update.absent_penalty });
728+ if (layer.toggles_observable ) {
729+ next_entries.push_back ({std::move (projected_toggled), item.mass1 * layer.p ,
730+ item.mass0 * layer.p , update.present_penalty });
731+ } else {
732+ next_entries.push_back ({std::move (projected_toggled), item.mass0 * layer.p ,
733+ item.mass1 * layer.p , update.present_penalty });
734+ }
782735 } else if (keep_absent) {
783736 next_entries.push_back (
784- {std::move (projected_state), item.obs_mask , item.mass * layer.q , update.absent_penalty });
737+ {std::move (projected_state), item.mass0 * layer.q , item.mass1 * layer.q ,
738+ update.absent_penalty });
785739 } else if (keep_present) {
786740 xor_compiled_wide_state (&projected_state, layer.projected_fault_mask_words );
787- next_entries.push_back ({std::move (projected_state), item.obs_mask ^ layer.obs_mask ,
788- item.mass * layer.p , update.present_penalty });
741+ if (layer.toggles_observable ) {
742+ next_entries.push_back ({std::move (projected_state), item.mass1 * layer.p ,
743+ item.mass0 * layer.p , update.present_penalty });
744+ } else {
745+ next_entries.push_back ({std::move (projected_state), item.mass0 * layer.p ,
746+ item.mass1 * layer.p , update.present_penalty });
747+ }
789748 }
790749 }
791750 auto t1 = std::chrono::high_resolution_clock::now ();
@@ -820,11 +779,8 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
820779 if (!fixed_wide_state_zero (item.state_words )) {
821780 continue ;
822781 }
823- if (item.obs_mask == 0 ) {
824- decoder->total_mass_obs0 += item.mass ;
825- } else if (item.obs_mask == 1 ) {
826- decoder->total_mass_obs1 += item.mass ;
827- }
782+ decoder->total_mass_obs0 += item.mass0 ;
783+ decoder->total_mass_obs1 += item.mass1 ;
828784 }
829785 if (decoder->total_mass_obs0 == 0.0 && decoder->total_mass_obs1 == 0.0 ) {
830786 decoder->low_confidence_flag = true ;
@@ -903,6 +859,9 @@ TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_)
903859 errors = get_errors_from_dem (config.dem .flattened ());
904860 num_detectors = config.dem .count_detectors ();
905861 num_observables = config.dem .count_observables ();
862+ if (num_observables > 1 ) {
863+ throw std::invalid_argument (" tesseract_trellis currently supports at most 1 observable" );
864+ }
906865
907866 all_possible_detector_words.assign (num_state_words (num_detectors), 0 );
908867 actual_detector_words_scratch.assign (all_possible_detector_words.size (), 0 );
0 commit comments