Skip to content

Commit 53ca8da

Browse files
committed
some optimizations to tesseract trellis
1 parent da4a0a3 commit 53ca8da

2 files changed

Lines changed: 20 additions & 82 deletions

File tree

src/BUILD

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -194,7 +194,6 @@ cc_binary(
194194
],
195195
)
196196

197-
198197
cc_binary(
199198
name = "tesseract_ftl",
200199
srcs = ["tesseract_ftl_main.cc"],

src/tesseract_trellis.cc

Lines changed: 20 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -319,66 +319,6 @@ double compute_penalty_from_scratch(uint64_t mismatch_mask,
319319
return total;
320320
}
321321

322-
double advance_penalty_row(double current_penalty, uint64_t current_mismatch,
323-
const TesseractTrellisDetcostTransition& transition) {
324-
if (current_penalty == INF) {
325-
return INF;
326-
}
327-
double total = current_penalty;
328-
for (size_t k = 0; k < transition.fault_local_indices.size(); ++k) {
329-
const uint64_t local_bit = uint64_t{1} << transition.fault_local_indices[k];
330-
if ((current_mismatch & local_bit) == 0) {
331-
continue;
332-
}
333-
double current_cost = transition.current_costs[k];
334-
double next_cost = transition.next_costs[k];
335-
if (next_cost == INF) {
336-
return INF;
337-
}
338-
total += next_cost - current_cost;
339-
}
340-
return total;
341-
}
342-
343-
double adjust_penalty_for_branch(double parent_penalty_next_row, uint64_t base_state,
344-
uint64_t current_target_bits, uint64_t next_target_bits,
345-
bool present_branch, uint64_t projected_state,
346-
const TesseractTrellisSmallLayerTemplate& layer) {
347-
if (parent_penalty_next_row == INF) {
348-
return compute_penalty_from_scratch(projected_state ^ next_target_bits, layer.next_frontier_costs);
349-
}
350-
351-
double total = parent_penalty_next_row;
352-
for (size_t k = 0; k < layer.detcost_transition.fault_local_indices.size(); ++k) {
353-
uint8_t local = layer.detcost_transition.fault_local_indices[k];
354-
int8_t next_local = layer.detcost_transition.next_local_indices[k];
355-
if (next_local < 0) {
356-
continue;
357-
}
358-
359-
const uint64_t state_bit =
360-
local < layer.previous_width ? ((base_state >> local) & 1ULL) : 0ULL;
361-
const uint64_t prev_mismatch =
362-
local < layer.previous_width ? (state_bit ^ ((current_target_bits >> local) & 1ULL)) : 0ULL;
363-
const uint64_t child_bit = state_bit ^ (present_branch ? 1ULL : 0ULL);
364-
const uint64_t child_mismatch = child_bit ^ ((next_target_bits >> next_local) & 1ULL);
365-
if (prev_mismatch == child_mismatch) {
366-
continue;
367-
}
368-
369-
double next_cost = layer.detcost_transition.next_costs[k];
370-
if (child_mismatch) {
371-
if (next_cost == INF) {
372-
return INF;
373-
}
374-
total += next_cost;
375-
} else {
376-
total -= next_cost;
377-
}
378-
}
379-
return total;
380-
}
381-
382322
void build_future_detcost_transitions(const std::vector<Fault>& faults, size_t num_detectors,
383323
std::vector<TesseractTrellisSmallLayerTemplate>* layers,
384324
std::vector<double>* initial_future_detcost) {
@@ -651,16 +591,24 @@ void TesseractTrellisDecoder::decode_shot(const std::vector<uint64_t>& detection
651591
std::vector<PackedMass> beam_entries;
652592
double initial_penalty = 0.0;
653593
if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) {
594+
std::vector<double> initial_frontier_costs;
595+
if (!small_layer_templates.empty()) {
596+
const auto& first_layer = small_layer_templates.front();
597+
initial_frontier_costs.resize(first_layer.current_active_detectors.size(), INF);
598+
for (size_t local = 0; local < first_layer.current_active_detectors.size(); ++local) {
599+
initial_frontier_costs[local] =
600+
initial_future_detcost[(size_t)first_layer.current_active_detectors[local]];
601+
}
602+
}
654603
initial_penalty = compute_penalty_from_scratch(
655604
current_target_bits_per_layer.empty() ? 0 : current_target_bits_per_layer.front(),
656-
initial_future_detcost);
605+
initial_frontier_costs);
657606
}
658607
beam_entries.push_back({pack_small_key(0, 0), 1.0, initial_penalty});
659608
max_beam_size_seen = 1;
660609

661610
for (size_t layer_index = 0; layer_index < small_layer_templates.size(); ++layer_index) {
662611
const auto& layer = small_layer_templates[layer_index];
663-
const uint64_t current_target_bits = current_target_bits_per_layer[layer_index];
664612
const uint64_t next_target_bits = next_target_bits_per_layer[layer_index];
665613
const uint64_t expected_retiring_bits = expected_retiring_bits_per_layer[layer_index];
666614
auto t0 = std::chrono::high_resolution_clock::now();
@@ -670,33 +618,17 @@ void TesseractTrellisDecoder::decode_shot(const std::vector<uint64_t>& detection
670618
++num_states_expanded;
671619
const uint64_t base_state = unpack_small_state(item.key);
672620
const uint64_t base_obs = unpack_small_obs(item.key);
673-
const double parent_penalty_next_row =
674-
config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked
675-
? advance_penalty_row(item.penalty, base_state ^ current_target_bits,
676-
layer.detcost_transition)
677-
: 0.0;
678621

679622
if (((base_state ^ expected_retiring_bits) & layer.retiring_mask) == 0) {
680623
uint64_t projected_state = project_small_state(base_state, layer.surviving_local_indices);
681-
double penalty =
682-
config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked
683-
? adjust_penalty_for_branch(parent_penalty_next_row, base_state, current_target_bits,
684-
next_target_bits, false, projected_state, layer)
685-
: 0.0;
686-
next_entries.push_back(
687-
{pack_small_key(projected_state, base_obs), item.mass * layer.q, penalty});
624+
next_entries.push_back({pack_small_key(projected_state, base_obs), item.mass * layer.q, 0.0});
688625
}
689626

690627
uint64_t toggled_state = base_state ^ layer.local_det_mask;
691628
if (((toggled_state ^ expected_retiring_bits) & layer.retiring_mask) == 0) {
692629
uint64_t projected_state = project_small_state(toggled_state, layer.surviving_local_indices);
693-
double penalty =
694-
config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked
695-
? adjust_penalty_for_branch(parent_penalty_next_row, base_state, current_target_bits,
696-
next_target_bits, true, projected_state, layer)
697-
: 0.0;
698-
next_entries.push_back({pack_small_key(projected_state, base_obs ^ layer.obs_flip_bit),
699-
item.mass * layer.p, penalty});
630+
next_entries.push_back(
631+
{pack_small_key(projected_state, base_obs ^ layer.obs_flip_bit), item.mass * layer.p, 0.0});
700632
}
701633
}
702634
auto t1 = std::chrono::high_resolution_clock::now();
@@ -723,6 +655,13 @@ void TesseractTrellisDecoder::decode_shot(const std::vector<uint64_t>& detection
723655
time_collapse_seconds +=
724656
std::chrono::duration_cast<std::chrono::microseconds>(t2 - t2a).count() / 1e6;
725657

658+
if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) {
659+
for (auto& item : beam_entries) {
660+
item.penalty = compute_penalty_from_scratch(unpack_small_state(item.key) ^ next_target_bits,
661+
layer.next_frontier_costs);
662+
}
663+
}
664+
726665
if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) {
727666
keep_top_states(beam_entries, config.beam_width, config.ranking_mode);
728667
} else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries ||

0 commit comments

Comments
 (0)