@@ -319,66 +319,6 @@ double compute_penalty_from_scratch(uint64_t mismatch_mask,
319319 return total;
320320}
321321
322- double advance_penalty_row (double current_penalty, uint64_t current_mismatch,
323- const TesseractTrellisDetcostTransition& transition) {
324- if (current_penalty == INF) {
325- return INF;
326- }
327- double total = current_penalty;
328- for (size_t k = 0 ; k < transition.fault_local_indices .size (); ++k) {
329- const uint64_t local_bit = uint64_t {1 } << transition.fault_local_indices [k];
330- if ((current_mismatch & local_bit) == 0 ) {
331- continue ;
332- }
333- double current_cost = transition.current_costs [k];
334- double next_cost = transition.next_costs [k];
335- if (next_cost == INF) {
336- return INF;
337- }
338- total += next_cost - current_cost;
339- }
340- return total;
341- }
342-
343- double adjust_penalty_for_branch (double parent_penalty_next_row, uint64_t base_state,
344- uint64_t current_target_bits, uint64_t next_target_bits,
345- bool present_branch, uint64_t projected_state,
346- const TesseractTrellisSmallLayerTemplate& layer) {
347- if (parent_penalty_next_row == INF) {
348- return compute_penalty_from_scratch (projected_state ^ next_target_bits, layer.next_frontier_costs );
349- }
350-
351- double total = parent_penalty_next_row;
352- for (size_t k = 0 ; k < layer.detcost_transition .fault_local_indices .size (); ++k) {
353- uint8_t local = layer.detcost_transition .fault_local_indices [k];
354- int8_t next_local = layer.detcost_transition .next_local_indices [k];
355- if (next_local < 0 ) {
356- continue ;
357- }
358-
359- const uint64_t state_bit =
360- local < layer.previous_width ? ((base_state >> local) & 1ULL ) : 0ULL ;
361- const uint64_t prev_mismatch =
362- local < layer.previous_width ? (state_bit ^ ((current_target_bits >> local) & 1ULL )) : 0ULL ;
363- const uint64_t child_bit = state_bit ^ (present_branch ? 1ULL : 0ULL );
364- const uint64_t child_mismatch = child_bit ^ ((next_target_bits >> next_local) & 1ULL );
365- if (prev_mismatch == child_mismatch) {
366- continue ;
367- }
368-
369- double next_cost = layer.detcost_transition .next_costs [k];
370- if (child_mismatch) {
371- if (next_cost == INF) {
372- return INF;
373- }
374- total += next_cost;
375- } else {
376- total -= next_cost;
377- }
378- }
379- return total;
380- }
381-
382322void build_future_detcost_transitions (const std::vector<Fault>& faults, size_t num_detectors,
383323 std::vector<TesseractTrellisSmallLayerTemplate>* layers,
384324 std::vector<double >* initial_future_detcost) {
@@ -651,16 +591,24 @@ void TesseractTrellisDecoder::decode_shot(const std::vector<uint64_t>& detection
651591 std::vector<PackedMass> beam_entries;
652592 double initial_penalty = 0.0 ;
653593 if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) {
594+ std::vector<double > initial_frontier_costs;
595+ if (!small_layer_templates.empty ()) {
596+ const auto & first_layer = small_layer_templates.front ();
597+ initial_frontier_costs.resize (first_layer.current_active_detectors .size (), INF);
598+ for (size_t local = 0 ; local < first_layer.current_active_detectors .size (); ++local) {
599+ initial_frontier_costs[local] =
600+ initial_future_detcost[(size_t )first_layer.current_active_detectors [local]];
601+ }
602+ }
654603 initial_penalty = compute_penalty_from_scratch (
655604 current_target_bits_per_layer.empty () ? 0 : current_target_bits_per_layer.front (),
656- initial_future_detcost );
605+ initial_frontier_costs );
657606 }
658607 beam_entries.push_back ({pack_small_key (0 , 0 ), 1.0 , initial_penalty});
659608 max_beam_size_seen = 1 ;
660609
661610 for (size_t layer_index = 0 ; layer_index < small_layer_templates.size (); ++layer_index) {
662611 const auto & layer = small_layer_templates[layer_index];
663- const uint64_t current_target_bits = current_target_bits_per_layer[layer_index];
664612 const uint64_t next_target_bits = next_target_bits_per_layer[layer_index];
665613 const uint64_t expected_retiring_bits = expected_retiring_bits_per_layer[layer_index];
666614 auto t0 = std::chrono::high_resolution_clock::now ();
@@ -670,33 +618,17 @@ void TesseractTrellisDecoder::decode_shot(const std::vector<uint64_t>& detection
670618 ++num_states_expanded;
671619 const uint64_t base_state = unpack_small_state (item.key );
672620 const uint64_t base_obs = unpack_small_obs (item.key );
673- const double parent_penalty_next_row =
674- config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked
675- ? advance_penalty_row (item.penalty , base_state ^ current_target_bits,
676- layer.detcost_transition )
677- : 0.0 ;
678621
679622 if (((base_state ^ expected_retiring_bits) & layer.retiring_mask ) == 0 ) {
680623 uint64_t projected_state = project_small_state (base_state, layer.surviving_local_indices );
681- double penalty =
682- config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked
683- ? adjust_penalty_for_branch (parent_penalty_next_row, base_state, current_target_bits,
684- next_target_bits, false , projected_state, layer)
685- : 0.0 ;
686- next_entries.push_back (
687- {pack_small_key (projected_state, base_obs), item.mass * layer.q , penalty});
624+ next_entries.push_back ({pack_small_key (projected_state, base_obs), item.mass * layer.q , 0.0 });
688625 }
689626
690627 uint64_t toggled_state = base_state ^ layer.local_det_mask ;
691628 if (((toggled_state ^ expected_retiring_bits) & layer.retiring_mask ) == 0 ) {
692629 uint64_t projected_state = project_small_state (toggled_state, layer.surviving_local_indices );
693- double penalty =
694- config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked
695- ? adjust_penalty_for_branch (parent_penalty_next_row, base_state, current_target_bits,
696- next_target_bits, true , projected_state, layer)
697- : 0.0 ;
698- next_entries.push_back ({pack_small_key (projected_state, base_obs ^ layer.obs_flip_bit ),
699- item.mass * layer.p , penalty});
630+ next_entries.push_back (
631+ {pack_small_key (projected_state, base_obs ^ layer.obs_flip_bit ), item.mass * layer.p , 0.0 });
700632 }
701633 }
702634 auto t1 = std::chrono::high_resolution_clock::now ();
@@ -723,6 +655,13 @@ void TesseractTrellisDecoder::decode_shot(const std::vector<uint64_t>& detection
723655 time_collapse_seconds +=
724656 std::chrono::duration_cast<std::chrono::microseconds>(t2 - t2a).count () / 1e6 ;
725657
658+ if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) {
659+ for (auto & item : beam_entries) {
660+ item.penalty = compute_penalty_from_scratch (unpack_small_state (item.key ) ^ next_target_bits,
661+ layer.next_frontier_costs );
662+ }
663+ }
664+
726665 if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) {
727666 keep_top_states (beam_entries, config.beam_width , config.ranking_mode );
728667 } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries ||
0 commit comments