Skip to content

Commit f016a98

Browse files
committed
1.5x speedup on trellis by accumulating states directly into pair buckets
1 parent 83ade9d commit f016a98

1 file changed

Lines changed: 122 additions & 27 deletions

File tree

src/tesseract_trellis.cc

Lines changed: 122 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@ struct FixedWideStateEntry {
6969
double score = -INF;
7070
};
7171

72+
template <size_t Words>
73+
struct FixedWidePairBucket {
74+
FixedWideStateWords<Words> key{};
75+
double mass0[2]{};
76+
double mass1[2]{};
77+
double penalty[2]{};
78+
uint8_t used_mask = 0;
79+
bool occupied = false;
80+
};
81+
7282
template <size_t Words>
7383
struct CompiledWideLayerTemplate {
7484
double q = 0.0;
@@ -313,6 +323,13 @@ TESSERACT_ALWAYS_INLINE double total_entry_mass(const FixedWideStateEntry<Words>
313323
return entry.mass0 + entry.mass1;
314324
}
315325

326+
TESSERACT_ALWAYS_INLINE uint64_t mix_splitmix64(uint64_t value) {
327+
value += 0x9e3779b97f4a7c15ULL;
328+
value = (value ^ (value >> 30)) * 0xbf58476d1ce4e5b9ULL;
329+
value = (value ^ (value >> 27)) * 0x94d049bb133111ebULL;
330+
return value ^ (value >> 31);
331+
}
332+
316333
void reset_kept_state_stats(TesseractTrellisDecoder* decoder) {
317334
decoder->kept_state_sample_count = 0;
318335
decoder->kept_state_min = 0;
@@ -422,6 +439,72 @@ void xor_compiled_wide_state(FixedWideStateWords<Words>* state_words,
422439
}
423440
}
424441

442+
template <size_t Words>
443+
TESSERACT_ALWAYS_INLINE uint64_t hash_fixed_wide_state(const FixedWideStateWords<Words>& state_words) {
444+
uint64_t hash = 0x123456789abcdef0ULL;
445+
for (size_t k = 0; k < Words; ++k) {
446+
hash ^= mix_splitmix64(state_words[k] + 0x9e3779b97f4a7c15ULL * (k + 1));
447+
hash = std::rotl(hash, 21);
448+
}
449+
return hash;
450+
}
451+
452+
template <size_t Words>
453+
void ensure_pair_bucket_capacity(std::vector<FixedWidePairBucket<Words>>* buckets,
454+
size_t num_parents) {
455+
const size_t required = std::bit_ceil(std::max<size_t>(16, num_parents * 2));
456+
if (buckets->size() < required) {
457+
buckets->resize(required);
458+
}
459+
}
460+
461+
template <size_t Words>
462+
void clear_pair_buckets(std::vector<FixedWidePairBucket<Words>>* buckets,
463+
std::vector<size_t>* used_bucket_indices) {
464+
for (size_t index : *used_bucket_indices) {
465+
(*buckets)[index].occupied = false;
466+
(*buckets)[index].used_mask = 0;
467+
}
468+
used_bucket_indices->clear();
469+
}
470+
471+
template <size_t Words>
472+
TESSERACT_ALWAYS_INLINE size_t find_or_insert_pair_bucket(
473+
std::vector<FixedWidePairBucket<Words>>* buckets, std::vector<size_t>* used_bucket_indices,
474+
const FixedWideStateWords<Words>& key) {
475+
const size_t mask = buckets->size() - 1;
476+
size_t index = hash_fixed_wide_state(key) & mask;
477+
while ((*buckets)[index].occupied) {
478+
if ((*buckets)[index].key == key) {
479+
return index;
480+
}
481+
index = (index + 1) & mask;
482+
}
483+
484+
auto& bucket = (*buckets)[index];
485+
bucket.occupied = true;
486+
bucket.key = key;
487+
bucket.used_mask = 0;
488+
used_bucket_indices->push_back(index);
489+
return index;
490+
}
491+
492+
template <size_t Words>
493+
TESSERACT_ALWAYS_INLINE void accumulate_pair_bucket_slot(FixedWidePairBucket<Words>* bucket,
494+
uint8_t slot, double mass0, double mass1,
495+
double penalty) {
496+
const uint8_t bit = (uint8_t)(1u << slot);
497+
if ((bucket->used_mask & bit) == 0) {
498+
bucket->mass0[slot] = mass0;
499+
bucket->mass1[slot] = mass1;
500+
bucket->penalty[slot] = penalty;
501+
bucket->used_mask |= bit;
502+
} else {
503+
bucket->mass0[slot] += mass0;
504+
bucket->mass1[slot] += mass1;
505+
}
506+
}
507+
425508
template <size_t Words>
426509
FixedWideStateWords<Words> project_compiled_wide_state(
427510
const FixedWideStateWords<Words>& state_words, const CompiledWideLayerTemplate<Words>& layer) {
@@ -687,6 +770,8 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
687770

688771
std::vector<FixedWideStateEntry<Words>> beam_entries;
689772
std::vector<FixedWideStateEntry<Words>> next_entries;
773+
std::vector<FixedWidePairBucket<Words>> pair_buckets;
774+
std::vector<size_t> used_bucket_indices;
690775
beam_entries.reserve(decoder->config.beam_width * 2 + 2);
691776
next_entries.reserve(decoder->config.beam_width * 4 + 4);
692777
beam_entries.push_back({{}, 1.0, 0.0, initial_penalty});
@@ -697,9 +782,10 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
697782
for (size_t layer_index = 0; layer_index < layers.size(); ++layer_index) {
698783
const auto& layer = layers[layer_index];
699784

785+
ensure_pair_bucket_capacity(&pair_buckets, beam_entries.size());
786+
clear_pair_buckets(&pair_buckets, &used_bucket_indices);
787+
700788
auto t0 = std::chrono::high_resolution_clock::now();
701-
next_entries.clear();
702-
next_entries.reserve(beam_entries.size() * 2);
703789

704790
if (decoder->config.verbose) {
705791
std::cout << "expanding layer " << layer_index << " / " << (layers.size() - 1)
@@ -717,43 +803,52 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase {
717803

718804
FixedWideStateWords<Words> projected_state =
719805
project_compiled_wide_state(item.state_words, layer);
806+
FixedWideStateWords<Words> projected_toggled = projected_state;
807+
xor_compiled_wide_state(&projected_toggled, layer.projected_fault_mask_words);
808+
const bool projected_is_key = !fixed_wide_state_less(projected_toggled, projected_state);
809+
const auto& bucket_key = projected_is_key ? projected_state : projected_toggled;
810+
const uint8_t absent_slot = projected_is_key ? 0 : 1;
811+
const uint8_t present_slot = projected_toggled == bucket_key ? 0 : 1;
812+
const size_t bucket_index =
813+
find_or_insert_pair_bucket(&pair_buckets, &used_bucket_indices, bucket_key);
814+
auto& bucket = pair_buckets[bucket_index];
720815
const bool keep_absent = update.absent_valid && layer.q != 0.0;
721816
const bool keep_present = update.present_valid && layer.p != 0.0;
722-
if (keep_absent && keep_present) {
723-
FixedWideStateWords<Words> projected_toggled = projected_state;
724-
xor_compiled_wide_state(&projected_toggled, layer.projected_fault_mask_words);
725-
next_entries.push_back(
726-
{std::move(projected_state), item.mass0 * layer.q, item.mass1 * layer.q,
727-
update.absent_penalty});
728-
if (layer.toggles_observable) {
729-
next_entries.push_back({std::move(projected_toggled), item.mass1 * layer.p,
730-
item.mass0 * layer.p, update.present_penalty});
731-
} else {
732-
next_entries.push_back({std::move(projected_toggled), item.mass0 * layer.p,
733-
item.mass1 * layer.p, update.present_penalty});
734-
}
735-
} else if (keep_absent) {
736-
next_entries.push_back(
737-
{std::move(projected_state), item.mass0 * layer.q, item.mass1 * layer.q,
738-
update.absent_penalty});
739-
} else if (keep_present) {
740-
xor_compiled_wide_state(&projected_state, layer.projected_fault_mask_words);
817+
818+
if (keep_absent) {
819+
accumulate_pair_bucket_slot(&bucket, absent_slot, item.mass0 * layer.q,
820+
item.mass1 * layer.q, update.absent_penalty);
821+
}
822+
if (keep_present) {
741823
if (layer.toggles_observable) {
742-
next_entries.push_back({std::move(projected_state), item.mass1 * layer.p,
743-
item.mass0 * layer.p, update.present_penalty});
824+
accumulate_pair_bucket_slot(&bucket, present_slot, item.mass1 * layer.p,
825+
item.mass0 * layer.p, update.present_penalty);
744826
} else {
745-
next_entries.push_back({std::move(projected_state), item.mass0 * layer.p,
746-
item.mass1 * layer.p, update.present_penalty});
827+
accumulate_pair_bucket_slot(&bucket, present_slot, item.mass0 * layer.p,
828+
item.mass1 * layer.p, update.present_penalty);
747829
}
748830
}
749831
}
750832
auto t1 = std::chrono::high_resolution_clock::now();
751833
decoder->time_expand_seconds +=
752834
std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count() / 1e6;
753835

754-
beam_entries.swap(next_entries);
755836
auto t2a = std::chrono::high_resolution_clock::now();
756-
merge_equal_compiled_keys_inplace(&beam_entries);
837+
next_entries.clear();
838+
next_entries.reserve(used_bucket_indices.size() * 2);
839+
for (size_t index : used_bucket_indices) {
840+
auto& bucket = pair_buckets[index];
841+
if ((bucket.used_mask & 1u) != 0) {
842+
next_entries.push_back({bucket.key, bucket.mass0[0], bucket.mass1[0], bucket.penalty[0]});
843+
}
844+
if ((bucket.used_mask & 2u) != 0) {
845+
auto other_state = bucket.key;
846+
xor_compiled_wide_state(&other_state, layer.projected_fault_mask_words);
847+
next_entries.push_back(
848+
{std::move(other_state), bucket.mass0[1], bucket.mass1[1], bucket.penalty[1]});
849+
}
850+
}
851+
beam_entries.swap(next_entries);
757852
auto t2 = std::chrono::high_resolution_clock::now();
758853
decoder->time_collapse_seconds +=
759854
std::chrono::duration_cast<std::chrono::microseconds>(t2 - t2a).count() / 1e6;

0 commit comments

Comments
 (0)