@@ -69,6 +69,16 @@ struct FixedWideStateEntry {
6969 double score = -INF;
7070};
7171
72+ template <size_t Words>
73+ struct FixedWidePairBucket {
74+ FixedWideStateWords<Words> key{};
75+ double mass0[2 ]{};
76+ double mass1[2 ]{};
77+ double penalty[2 ]{};
78+ uint8_t used_mask = 0 ;
79+ bool occupied = false ;
80+ };
81+
7282template <size_t Words>
7383struct CompiledWideLayerTemplate {
7484 double q = 0.0 ;
@@ -313,6 +323,13 @@ TESSERACT_ALWAYS_INLINE double total_entry_mass(const FixedWideStateEntry<Words>
313323 return entry.mass0 + entry.mass1 ;
314324}
315325
// Scrambles a 64-bit value: adds the constant 0x9e3779b97f4a7c15, then applies
// two xor-shift/multiply rounds followed by a final xor-shift. All arithmetic
// is unsigned and wraps modulo 2^64.
TESSERACT_ALWAYS_INLINE uint64_t mix_splitmix64(uint64_t value) {
  uint64_t z = value + 0x9e3779b97f4a7c15ULL;
  z ^= z >> 30;
  z *= 0xbf58476d1ce4e5b9ULL;
  z ^= z >> 27;
  z *= 0x94d049bb133111ebULL;
  z ^= z >> 31;
  return z;
}
332+
316333void reset_kept_state_stats (TesseractTrellisDecoder* decoder) {
317334 decoder->kept_state_sample_count = 0 ;
318335 decoder->kept_state_min = 0 ;
@@ -422,6 +439,72 @@ void xor_compiled_wide_state(FixedWideStateWords<Words>* state_words,
422439 }
423440}
424441
442+ template <size_t Words>
443+ TESSERACT_ALWAYS_INLINE uint64_t hash_fixed_wide_state (const FixedWideStateWords<Words>& state_words) {
444+ uint64_t hash = 0x123456789abcdef0ULL ;
445+ for (size_t k = 0 ; k < Words; ++k) {
446+ hash ^= mix_splitmix64 (state_words[k] + 0x9e3779b97f4a7c15ULL * (k + 1 ));
447+ hash = std::rotl (hash, 21 );
448+ }
449+ return hash;
450+ }
451+
452+ template <size_t Words>
453+ void ensure_pair_bucket_capacity (std::vector<FixedWidePairBucket<Words>>* buckets,
454+ size_t num_parents) {
455+ const size_t required = std::bit_ceil (std::max<size_t >(16 , num_parents * 2 ));
456+ if (buckets->size () < required) {
457+ buckets->resize (required);
458+ }
459+ }
460+
461+ template <size_t Words>
462+ void clear_pair_buckets (std::vector<FixedWidePairBucket<Words>>* buckets,
463+ std::vector<size_t >* used_bucket_indices) {
464+ for (size_t index : *used_bucket_indices) {
465+ (*buckets)[index].occupied = false ;
466+ (*buckets)[index].used_mask = 0 ;
467+ }
468+ used_bucket_indices->clear ();
469+ }
470+
471+ template <size_t Words>
472+ TESSERACT_ALWAYS_INLINE size_t find_or_insert_pair_bucket (
473+ std::vector<FixedWidePairBucket<Words>>* buckets, std::vector<size_t >* used_bucket_indices,
474+ const FixedWideStateWords<Words>& key) {
475+ const size_t mask = buckets->size () - 1 ;
476+ size_t index = hash_fixed_wide_state (key) & mask;
477+ while ((*buckets)[index].occupied ) {
478+ if ((*buckets)[index].key == key) {
479+ return index;
480+ }
481+ index = (index + 1 ) & mask;
482+ }
483+
484+ auto & bucket = (*buckets)[index];
485+ bucket.occupied = true ;
486+ bucket.key = key;
487+ bucket.used_mask = 0 ;
488+ used_bucket_indices->push_back (index);
489+ return index;
490+ }
491+
492+ template <size_t Words>
493+ TESSERACT_ALWAYS_INLINE void accumulate_pair_bucket_slot (FixedWidePairBucket<Words>* bucket,
494+ uint8_t slot, double mass0, double mass1,
495+ double penalty) {
496+ const uint8_t bit = (uint8_t )(1u << slot);
497+ if ((bucket->used_mask & bit) == 0 ) {
498+ bucket->mass0 [slot] = mass0;
499+ bucket->mass1 [slot] = mass1;
500+ bucket->penalty [slot] = penalty;
501+ bucket->used_mask |= bit;
502+ } else {
503+ bucket->mass0 [slot] += mass0;
504+ bucket->mass1 [slot] += mass1;
505+ }
506+ }
507+
425508template <size_t Words>
426509FixedWideStateWords<Words> project_compiled_wide_state (
427510 const FixedWideStateWords<Words>& state_words, const CompiledWideLayerTemplate<Words>& layer) {
@@ -687,6 +770,8 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
687770
688771 std::vector<FixedWideStateEntry<Words>> beam_entries;
689772 std::vector<FixedWideStateEntry<Words>> next_entries;
773+ std::vector<FixedWidePairBucket<Words>> pair_buckets;
774+ std::vector<size_t > used_bucket_indices;
690775 beam_entries.reserve (decoder->config .beam_width * 2 + 2 );
691776 next_entries.reserve (decoder->config .beam_width * 4 + 4 );
692777 beam_entries.push_back ({{}, 1.0 , 0.0 , initial_penalty});
@@ -697,9 +782,10 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
697782 for (size_t layer_index = 0 ; layer_index < layers.size (); ++layer_index) {
698783 const auto & layer = layers[layer_index];
699784
785+ ensure_pair_bucket_capacity (&pair_buckets, beam_entries.size ());
786+ clear_pair_buckets (&pair_buckets, &used_bucket_indices);
787+
700788 auto t0 = std::chrono::high_resolution_clock::now ();
701- next_entries.clear ();
702- next_entries.reserve (beam_entries.size () * 2 );
703789
704790 if (decoder->config .verbose ) {
705791 std::cout << " expanding layer " << layer_index << " / " << (layers.size () - 1 )
@@ -717,43 +803,52 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
717803
718804 FixedWideStateWords<Words> projected_state =
719805 project_compiled_wide_state (item.state_words , layer);
806+ FixedWideStateWords<Words> projected_toggled = projected_state;
807+ xor_compiled_wide_state (&projected_toggled, layer.projected_fault_mask_words );
808+ const bool projected_is_key = !fixed_wide_state_less (projected_toggled, projected_state);
809+ const auto & bucket_key = projected_is_key ? projected_state : projected_toggled;
810+ const uint8_t absent_slot = projected_is_key ? 0 : 1 ;
811+ const uint8_t present_slot = projected_toggled == bucket_key ? 0 : 1 ;
812+ const size_t bucket_index =
813+ find_or_insert_pair_bucket (&pair_buckets, &used_bucket_indices, bucket_key);
814+ auto & bucket = pair_buckets[bucket_index];
720815 const bool keep_absent = update.absent_valid && layer.q != 0.0 ;
721816 const bool keep_present = update.present_valid && layer.p != 0.0 ;
722- if (keep_absent && keep_present) {
723- FixedWideStateWords<Words> projected_toggled = projected_state;
724- xor_compiled_wide_state (&projected_toggled, layer.projected_fault_mask_words );
725- next_entries.push_back (
726- {std::move (projected_state), item.mass0 * layer.q , item.mass1 * layer.q ,
727- update.absent_penalty });
728- if (layer.toggles_observable ) {
729- next_entries.push_back ({std::move (projected_toggled), item.mass1 * layer.p ,
730- item.mass0 * layer.p , update.present_penalty });
731- } else {
732- next_entries.push_back ({std::move (projected_toggled), item.mass0 * layer.p ,
733- item.mass1 * layer.p , update.present_penalty });
734- }
735- } else if (keep_absent) {
736- next_entries.push_back (
737- {std::move (projected_state), item.mass0 * layer.q , item.mass1 * layer.q ,
738- update.absent_penalty });
739- } else if (keep_present) {
740- xor_compiled_wide_state (&projected_state, layer.projected_fault_mask_words );
817+
818+ if (keep_absent) {
819+ accumulate_pair_bucket_slot (&bucket, absent_slot, item.mass0 * layer.q ,
820+ item.mass1 * layer.q , update.absent_penalty );
821+ }
822+ if (keep_present) {
741823 if (layer.toggles_observable ) {
742- next_entries. push_back ({ std::move (projected_state) , item.mass1 * layer.p ,
743- item.mass0 * layer.p , update.present_penalty } );
824+ accumulate_pair_bucket_slot (&bucket, present_slot , item.mass1 * layer.p ,
825+ item.mass0 * layer.p , update.present_penalty );
744826 } else {
745- next_entries. push_back ({ std::move (projected_state) , item.mass0 * layer.p ,
746- item.mass1 * layer.p , update.present_penalty } );
827+ accumulate_pair_bucket_slot (&bucket, present_slot , item.mass0 * layer.p ,
828+ item.mass1 * layer.p , update.present_penalty );
747829 }
748830 }
749831 }
750832 auto t1 = std::chrono::high_resolution_clock::now ();
751833 decoder->time_expand_seconds +=
752834 std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count () / 1e6 ;
753835
754- beam_entries.swap (next_entries);
755836 auto t2a = std::chrono::high_resolution_clock::now ();
756- merge_equal_compiled_keys_inplace (&beam_entries);
837+ next_entries.clear ();
838+ next_entries.reserve (used_bucket_indices.size () * 2 );
839+ for (size_t index : used_bucket_indices) {
840+ auto & bucket = pair_buckets[index];
841+ if ((bucket.used_mask & 1u ) != 0 ) {
842+ next_entries.push_back ({bucket.key , bucket.mass0 [0 ], bucket.mass1 [0 ], bucket.penalty [0 ]});
843+ }
844+ if ((bucket.used_mask & 2u ) != 0 ) {
845+ auto other_state = bucket.key ;
846+ xor_compiled_wide_state (&other_state, layer.projected_fault_mask_words );
847+ next_entries.push_back (
848+ {std::move (other_state), bucket.mass0 [1 ], bucket.mass1 [1 ], bucket.penalty [1 ]});
849+ }
850+ }
851+ beam_entries.swap (next_entries);
757852 auto t2 = std::chrono::high_resolution_clock::now ();
758853 decoder->time_collapse_seconds +=
759854 std::chrono::duration_cast<std::chrono::microseconds>(t2 - t2a).count () / 1e6 ;
0 commit comments