diff --git a/src/simplex_main.cc b/src/simplex_main.cc index e8da07d..7939a91 100644 --- a/src/simplex_main.cc +++ b/src/simplex_main.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -107,6 +108,9 @@ struct Args { "Cannot load observable flips without a corresponding detection " "event data file."); } + if (num_threads == 0) { + throw std::invalid_argument("--threads must be at least 1."); + } if (num_threads > 1000) { throw std::invalid_argument( "There is a maximum limit of 1000 threads imposed to avoid " @@ -367,7 +371,8 @@ int main(int argc, char* argv[]) { program.add_argument("--threads") .help("Number of decoder threads to use") .metavar("N") - .default_value(size_t(std::thread::hardware_concurrency())) + .default_value(size_t( + std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency())) .store_into(args.num_threads); program.add_argument("--parallelize-ilp") .help( @@ -416,97 +421,64 @@ int main(int argc, char* argv[]) { std::vector shots; std::unique_ptr writer; args.extract(config, shots, writer); - std::atomic next_unclaimed_shot; - std::vector> finished(shots.size()); std::vector obs_predicted(shots.size()); std::vector cost_predicted(shots.size()); std::vector decoding_time_seconds(shots.size()); - std::vector decoder_threads; const stim::DetectorErrorModel original_dem = config.dem.flattened(); - std::vector> error_use_totals(original_dem.count_errors()); + std::vector> decoders(args.num_threads); + std::vector> error_use_per_thread( + args.num_threads, std::vector(original_dem.count_errors())); bool has_obs = args.has_observables(); - std::atomic worker_threads_please_terminate = false; - std::atomic num_worker_threads_active; - for (size_t t = 0; t < args.num_threads; ++t) { - // After this value returns to 0, we know that no further shots will - // transition to finished. 
- ++num_worker_threads_active; - decoder_threads.push_back(std::thread([&config, &next_unclaimed_shot, &shots, &obs_predicted, - &cost_predicted, &decoding_time_seconds, &finished, - &error_use_totals, &has_obs, - &worker_threads_please_terminate, - &num_worker_threads_active, &original_dem]() { - SimplexDecoder decoder(config); - std::vector error_use(original_dem.count_errors()); - for (size_t shot; - !worker_threads_please_terminate and ((shot = next_unclaimed_shot++) < shots.size());) { + size_t num_errors = 0; + double total_time_seconds = 0; + size_t num_observables = config.dem.count_observables(); + size_t shot = parallel_for_shots_in_order( + shots.size(), args.num_threads, + [&](size_t thread_index, size_t shot_index) { + if (!decoders[thread_index]) { + decoders[thread_index] = std::make_unique(config); + } + auto& decoder = *decoders[thread_index]; + auto& error_use = error_use_per_thread[thread_index]; auto start_time = std::chrono::high_resolution_clock::now(); - decoder.decode_to_errors(shots[shot].hits); + decoder.decode_to_errors(shots[shot_index].hits); auto stop_time = std::chrono::high_resolution_clock::now(); - decoding_time_seconds[shot] = + decoding_time_seconds[shot_index] = std::chrono::duration_cast(stop_time - start_time).count() / 1e6; - obs_predicted[shot] = + obs_predicted[shot_index] = vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer)); - cost_predicted[shot] = decoder.cost_from_errors(decoder.predicted_errors_buffer); - if (!has_obs or shots[shot].obs_mask_as_u64() == obs_predicted[shot]) { - // Only count the error uses for shots that did not have a logical - // error, if we know the obs flips. 
+ cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer); + if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { for (size_t ei : decoder.predicted_errors_buffer) { ++error_use[ei]; } } - finished[shot] = true; - } - // Add the error counts to the total - for (size_t ei = 0; ei < config.dem.count_errors(); ++ei) { - error_use_totals[ei] += error_use[ei]; - } - --num_worker_threads_active; - })); - } - size_t num_errors = 0; - double total_time_seconds = 0; - size_t num_observables = config.dem.count_observables(); - size_t shot = 0; - for (; shot < shots.size(); ++shot) { - while (num_worker_threads_active and !finished[shot]) { - // We break once the number of active worker threads is 0, at which point - // there will be no further changes to finished[shot]. - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - // There can be no further changes to finished[shot]. If it is true, we - // process it and go to the next shot. If it is false, we break now as it - // will never be decoded and no subsequent shots will be decoded. - if (!finished[shot]) { - assert(num_worker_threads_active == 0); - // This and subsequent shots will never become decoded. 
- break; - } - - if (writer) { - writer->write_bits((uint8_t*)&obs_predicted[shot], num_observables); - writer->write_end(); - } - - if (obs_predicted[shot] != shots[shot].obs_mask_as_u64()) ++num_errors; - - total_time_seconds += decoding_time_seconds[shot]; - - if (args.print_stats) { - std::cout << "num_shots = " << (shot + 1) << " num_errors = " << num_errors - << " total_time_seconds = " << total_time_seconds << std::endl; - std::cout << "cost = " << cost_predicted[shot] << std::endl; - std::cout.flush(); - } + }, + [&](size_t shot_index) { + if (writer) { + writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables); + writer->write_end(); + } + if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) { + ++num_errors; + } + total_time_seconds += decoding_time_seconds[shot_index]; + if (args.print_stats) { + std::cout << "num_shots = " << (shot_index + 1) << " num_errors = " << num_errors + << " total_time_seconds = " << total_time_seconds << std::endl; + std::cout << "cost = " << cost_predicted[shot_index] << std::endl; + std::cout.flush(); + } + return num_errors < args.max_errors; + }); - if (num_errors >= args.max_errors) { - worker_threads_please_terminate = true; + std::vector error_use_totals(original_dem.count_errors()); + for (const auto& error_use : error_use_per_thread) { + for (size_t ei = 0; ei < error_use_totals.size(); ++ei) { + error_use_totals[ei] += error_use[ei]; } } - for (size_t t = 0; t < args.num_threads; ++t) { - decoder_threads[t].join(); - } if (!args.dem_out_fname.empty()) { std::vector counts(error_use_totals.begin(), error_use_totals.end()); diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index 65fb4e2..ab5ed9c 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -120,6 +121,9 @@ struct Args { "Cannot load observable flips without a corresponding detection " "event data file."); } + if (num_threads 
== 0) { + throw std::invalid_argument("--threads must be at least 1."); + } if (num_threads > 1000) { throw std::invalid_argument( "There is a maximum limit of 1000 threads imposed to avoid " @@ -424,7 +428,8 @@ int main(int argc, char* argv[]) { program.add_argument("--threads") .help("Number of decoder threads to use") .metavar("N") - .default_value(size_t(std::thread::hardware_concurrency())) + .default_value(size_t( + std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency())) .store_into(args.num_threads); program.add_argument("--beam") .help("Beam to use for truncation (default = infinity)") @@ -475,105 +480,71 @@ int main(int argc, char* argv[]) { std::vector shots; std::unique_ptr writer; args.extract(config, shots, writer); - std::atomic next_unclaimed_shot; - std::vector> finished(shots.size()); std::vector obs_predicted(shots.size()); std::vector cost_predicted(shots.size()); std::vector decoding_time_seconds(shots.size()); std::vector> low_confidence(shots.size()); - std::vector decoder_threads; const stim::DetectorErrorModel original_dem = config.dem.flattened(); - std::vector> error_use_totals(original_dem.count_errors()); + std::vector> decoders(args.num_threads); + std::vector> error_use_per_thread( + args.num_threads, std::vector(original_dem.count_errors())); bool has_obs = args.has_observables(); - std::atomic worker_threads_please_terminate = false; - std::atomic num_worker_threads_active; - for (size_t t = 0; t < args.num_threads; ++t) { - // After this value returns to 0, we know that no further shots will - // transition to finished. 
- ++num_worker_threads_active; - decoder_threads.push_back(std::thread([&config, &next_unclaimed_shot, &shots, &obs_predicted, - &cost_predicted, &decoding_time_seconds, &low_confidence, - &finished, &error_use_totals, &has_obs, - &worker_threads_please_terminate, - &num_worker_threads_active, &original_dem]() { - TesseractDecoder decoder(config); - std::vector error_use(original_dem.count_errors()); - for (size_t shot; - !worker_threads_please_terminate and ((shot = next_unclaimed_shot++) < shots.size());) { + size_t num_errors = 0; + size_t num_low_confidence = 0; + double total_time_seconds = 0; + size_t num_observables = config.dem.count_observables(); + size_t shot = parallel_for_shots_in_order( + shots.size(), args.num_threads, + [&](size_t thread_index, size_t shot_index) { + if (!decoders[thread_index]) { + decoders[thread_index] = std::make_unique(config); + } + auto& decoder = *decoders[thread_index]; + auto& error_use = error_use_per_thread[thread_index]; auto start_time = std::chrono::high_resolution_clock::now(); - decoder.decode_to_errors(shots[shot].hits); + decoder.decode_to_errors(shots[shot_index].hits); auto stop_time = std::chrono::high_resolution_clock::now(); - decoding_time_seconds[shot] = + decoding_time_seconds[shot_index] = std::chrono::duration_cast(stop_time - start_time).count() / 1e6; - obs_predicted[shot] = + obs_predicted[shot_index] = vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer)); - low_confidence[shot] = decoder.low_confidence_flag; - cost_predicted[shot] = decoder.cost_from_errors(decoder.predicted_errors_buffer); - if (!has_obs or shots[shot].obs_mask_as_u64() == obs_predicted[shot]) { - // Only count the error uses for shots that did not have a logical - // error, if we know the obs flips. 
+ low_confidence[shot_index] = decoder.low_confidence_flag; + cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer); + if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { for (size_t ei : decoder.predicted_errors_buffer) { ++error_use[ei]; } } - finished[shot] = true; - } - // Add the error counts to the total - for (size_t ei = 0; ei < error_use_totals.size(); ++ei) { - error_use_totals[ei] += error_use[ei]; - } - --num_worker_threads_active; - })); - } - size_t num_errors = 0; - size_t num_low_confidence = 0; - double total_time_seconds = 0; - size_t num_observables = config.dem.count_observables(); - size_t shot = 0; - for (; shot < shots.size(); ++shot) { - while (num_worker_threads_active and !finished[shot]) { - // We break once the number of active worker threads is 0, at which point - // there will be no further changes to finished[shot]. - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - // There can be no further changes to finished[shot]. If it is true, we - // process it and go to the next shot. If it is false, we break now as it - // will never be decoded and no subsequent shots will be decoded. - if (!finished[shot]) { - assert(num_worker_threads_active == 0); - // This and subsequent shots will never become decoded. 
- break; - } - - if (writer) { - writer->write_bits((uint8_t*)&obs_predicted[shot], num_observables); - writer->write_end(); - } - - if (low_confidence[shot]) { - ++num_low_confidence; - } else if (obs_predicted[shot] != shots[shot].obs_mask_as_u64()) { - ++num_errors; - } - - total_time_seconds += decoding_time_seconds[shot]; - - if (args.print_stats) { - std::cout << "num_shots = " << (shot + 1) << " num_low_confidence = " << num_low_confidence - << " num_errors = " << num_errors << " total_time_seconds = " << total_time_seconds - << std::endl; - std::cout << "cost = " << cost_predicted[shot] << std::endl; - std::cout.flush(); - } + }, + [&](size_t shot_index) { + if (writer) { + writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables); + writer->write_end(); + } + if (low_confidence[shot_index]) { + ++num_low_confidence; + } else if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) { + ++num_errors; + } + total_time_seconds += decoding_time_seconds[shot_index]; + if (args.print_stats) { + std::cout << "num_shots = " << (shot_index + 1) + << " num_low_confidence = " << num_low_confidence + << " num_errors = " << num_errors + << " total_time_seconds = " << total_time_seconds << std::endl; + std::cout << "cost = " << cost_predicted[shot_index] << std::endl; + std::cout.flush(); + } + return num_errors < args.max_errors; + }); - if (num_errors >= args.max_errors) { - worker_threads_please_terminate = true; + std::vector error_use_totals(original_dem.count_errors()); + for (const auto& error_use : error_use_per_thread) { + for (size_t ei = 0; ei < error_use_totals.size(); ++ei) { + error_use_totals[ei] += error_use[ei]; } } - for (size_t t = 0; t < args.num_threads; ++t) { - decoder_threads[t].join(); - } if (!args.dem_out_fname.empty()) { std::vector counts(error_use_totals.begin(), error_use_totals.end()); diff --git a/src/utils.h b/src/utils.h index 73d7817..fe89b4f 100644 --- a/src/utils.h +++ b/src/utils.h @@ -16,10 +16,13 @@ 
// Applies a shot-wise worker function in parallel while consuming completed
// shots in increasing order.
//
// process_shot(thread_index, shot_index):
//   - Runs on worker threads.
//   - thread_index is stable for each worker and lies in [0, num_threads).
//     At most min(num_threads, num_shots) workers are started, so some
//     indices may never be used.
//   - Must not throw: the worker body has no exception handler, so an
//     exception escaping a std::thread terminates the process.
//
// consume_shot(shot_index):
//   - Runs on the caller thread in increasing shot order.
//   - Returns true to keep going, false to request an early stop.
//   - May throw; workers are stopped and joined before the exception
//     propagates to the caller.
//
// If consume_shot returns false, workers stop claiming new shots but always
// finish any shot they already started. Shots that were already claimed are
// still handed to consume_shot, in order.
//
// Returns the number of shots that were passed to consume_shot. This is
// num_shots unless an early stop was requested (or num_threads is 0).
template <typename ProcessShot, typename ConsumeShot>
size_t parallel_for_shots_in_order(size_t num_shots, size_t num_threads, ProcessShot&& process_shot,
                                   ConsumeShot&& consume_shot) {
  // Upper bound on how long the consumer sleeps between checks of a shot.
  constexpr std::chrono::milliseconds kMaxPollInterval{100};

  // Never start more workers than there are shots to decode.
  const size_t num_workers = num_threads < num_shots ? num_threads : num_shots;

  std::atomic<size_t> next_unclaimed_shot{0};
  // vector(n) value-initializes its elements, so every flag starts as false.
  std::vector<std::atomic<bool>> finished(num_shots);
  std::atomic<bool> worker_threads_please_terminate{false};
  // Set before any worker starts. After this value returns to 0, no further
  // shots will transition to finished.
  std::atomic<size_t> num_worker_threads_active{num_workers};
  std::vector<std::thread> workers;
  workers.reserve(num_workers);

  // Stops and joins the workers on every exit path. Without this, an
  // exception from consume_shot (or from thread creation) would destroy
  // joinable std::thread objects, which calls std::terminate.
  // Declared after the state the workers reference, so it is destroyed (and
  // the workers are joined) before that state goes away.
  struct WorkerJoiner {
    std::atomic<bool>& stop_requested;
    std::vector<std::thread>& threads;
    ~WorkerJoiner() {
      // Harmless on the normal path: by then no unclaimed shot remains or all
      // workers have already exited.
      stop_requested = true;
      for (auto& thread : threads) {
        if (thread.joinable()) {
          thread.join();
        }
      }
    }
  } joiner{worker_threads_please_terminate, workers};

  for (size_t t = 0; t < num_workers; ++t) {
    workers.emplace_back([&, t]() {
      // The terminate flag is tested first so that a stop request prevents
      // claiming (and thereby skipping) another shot.
      for (size_t shot;
           !worker_threads_please_terminate && ((shot = next_unclaimed_shot++) < num_shots);) {
        process_shot(t, shot);
        finished[shot] = true;
      }
      --num_worker_threads_active;
    });
  }

  size_t shot = 0;
  for (; shot < num_shots; ++shot) {
    // Back off from 1 ms up to kMaxPollInterval so fast shots are picked up
    // promptly while slow shots do not cause busy polling.
    std::chrono::milliseconds poll_interval{1};
    while (num_worker_threads_active != 0 && !finished[shot]) {
      std::this_thread::sleep_for(poll_interval);
      const auto doubled = poll_interval * 2;
      poll_interval = doubled < kMaxPollInterval ? doubled : kMaxPollInterval;
    }
    // There can be no further changes to finished[shot]. If it is false, this
    // and all subsequent shots will never be decoded.
    if (!finished[shot]) {
      assert(num_worker_threads_active == 0);
      break;
    }
    if (!consume_shot(shot)) {
      worker_threads_please_terminate = true;
    }
  }

  return shot;
}