Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 47 additions & 75 deletions src/simplex_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <argparse/argparse.hpp>
#include <atomic>
#include <fstream>
#include <memory>
#include <nlohmann/json.hpp>
#include <thread>

Expand Down Expand Up @@ -107,6 +108,9 @@ struct Args {
"Cannot load observable flips without a corresponding detection "
"event data file.");
}
if (num_threads == 0) {
throw std::invalid_argument("--threads must be at least 1.");
}
if (num_threads > 1000) {
throw std::invalid_argument(
"There is a maximum limit of 1000 threads imposed to avoid "
Expand Down Expand Up @@ -367,7 +371,8 @@ int main(int argc, char* argv[]) {
program.add_argument("--threads")
.help("Number of decoder threads to use")
.metavar("N")
.default_value(size_t(std::thread::hardware_concurrency()))
.default_value(size_t(
std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency()))
.store_into(args.num_threads);
program.add_argument("--parallelize-ilp")
.help(
Expand Down Expand Up @@ -416,97 +421,64 @@ int main(int argc, char* argv[]) {
std::vector<stim::SparseShot> shots;
std::unique_ptr<stim::MeasureRecordWriter> writer;
args.extract(config, shots, writer);
std::atomic<size_t> next_unclaimed_shot;
std::vector<std::atomic<bool>> finished(shots.size());
std::vector<uint64_t> obs_predicted(shots.size());
std::vector<double> cost_predicted(shots.size());
std::vector<double> decoding_time_seconds(shots.size());
std::vector<std::thread> decoder_threads;
const stim::DetectorErrorModel original_dem = config.dem.flattened();
std::vector<std::atomic<size_t>> error_use_totals(original_dem.count_errors());
std::vector<std::unique_ptr<SimplexDecoder>> decoders(args.num_threads);
std::vector<std::vector<size_t>> error_use_per_thread(
args.num_threads, std::vector<size_t>(original_dem.count_errors()));
bool has_obs = args.has_observables();
std::atomic<bool> worker_threads_please_terminate = false;
std::atomic<size_t> num_worker_threads_active;
for (size_t t = 0; t < args.num_threads; ++t) {
// After this value returns to 0, we know that no further shots will
// transition to finished.
++num_worker_threads_active;
decoder_threads.push_back(std::thread([&config, &next_unclaimed_shot, &shots, &obs_predicted,
&cost_predicted, &decoding_time_seconds, &finished,
&error_use_totals, &has_obs,
&worker_threads_please_terminate,
&num_worker_threads_active, &original_dem]() {
SimplexDecoder decoder(config);
std::vector<size_t> error_use(original_dem.count_errors());
for (size_t shot;
!worker_threads_please_terminate and ((shot = next_unclaimed_shot++) < shots.size());) {
size_t num_errors = 0;
double total_time_seconds = 0;
size_t num_observables = config.dem.count_observables();
size_t shot = parallel_for_shots_in_order(
shots.size(), args.num_threads,
[&](size_t thread_index, size_t shot_index) {
if (!decoders[thread_index]) {
decoders[thread_index] = std::make_unique<SimplexDecoder>(config);
}
auto& decoder = *decoders[thread_index];
auto& error_use = error_use_per_thread[thread_index];
auto start_time = std::chrono::high_resolution_clock::now();
decoder.decode_to_errors(shots[shot].hits);
decoder.decode_to_errors(shots[shot_index].hits);
auto stop_time = std::chrono::high_resolution_clock::now();
decoding_time_seconds[shot] =
decoding_time_seconds[shot_index] =
std::chrono::duration_cast<std::chrono::microseconds>(stop_time - start_time).count() /
1e6;
obs_predicted[shot] =
obs_predicted[shot_index] =
vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer));
cost_predicted[shot] = decoder.cost_from_errors(decoder.predicted_errors_buffer);
if (!has_obs or shots[shot].obs_mask_as_u64() == obs_predicted[shot]) {
// Only count the error uses for shots that did not have a logical
// error, if we know the obs flips.
cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer);
if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) {
for (size_t ei : decoder.predicted_errors_buffer) {
++error_use[ei];
}
}
finished[shot] = true;
}
// Add the error counts to the total
for (size_t ei = 0; ei < config.dem.count_errors(); ++ei) {
error_use_totals[ei] += error_use[ei];
}
--num_worker_threads_active;
}));
}
size_t num_errors = 0;
double total_time_seconds = 0;
size_t num_observables = config.dem.count_observables();
size_t shot = 0;
for (; shot < shots.size(); ++shot) {
while (num_worker_threads_active and !finished[shot]) {
// We break once the number of active worker threads is 0, at which point
// there will be no further changes to finished[shot].
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
// There can be no further changes to finished[shot]. If it is true, we
// process it and go to the next shot. If it is false, we break now as it
// will never be decoded and no subsequent shots will be decoded.
if (!finished[shot]) {
assert(num_worker_threads_active == 0);
// This and subsequent shots will never become decoded.
break;
}

if (writer) {
writer->write_bits((uint8_t*)&obs_predicted[shot], num_observables);
writer->write_end();
}

if (obs_predicted[shot] != shots[shot].obs_mask_as_u64()) ++num_errors;

total_time_seconds += decoding_time_seconds[shot];

if (args.print_stats) {
std::cout << "num_shots = " << (shot + 1) << " num_errors = " << num_errors
<< " total_time_seconds = " << total_time_seconds << std::endl;
std::cout << "cost = " << cost_predicted[shot] << std::endl;
std::cout.flush();
}
},
[&](size_t shot_index) {
if (writer) {
writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables);
writer->write_end();
}
if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) {
++num_errors;
}
total_time_seconds += decoding_time_seconds[shot_index];
if (args.print_stats) {
std::cout << "num_shots = " << (shot_index + 1) << " num_errors = " << num_errors
<< " total_time_seconds = " << total_time_seconds << std::endl;
std::cout << "cost = " << cost_predicted[shot_index] << std::endl;
std::cout.flush();
}
return num_errors < args.max_errors;
});

if (num_errors >= args.max_errors) {
worker_threads_please_terminate = true;
std::vector<size_t> error_use_totals(original_dem.count_errors());
for (const auto& error_use : error_use_per_thread) {
for (size_t ei = 0; ei < error_use_totals.size(); ++ei) {
error_use_totals[ei] += error_use[ei];
}
}
for (size_t t = 0; t < args.num_threads; ++t) {
decoder_threads[t].join();
}

if (!args.dem_out_fname.empty()) {
std::vector<size_t> counts(error_use_totals.begin(), error_use_totals.end());
Expand Down
135 changes: 53 additions & 82 deletions src/tesseract_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <argparse/argparse.hpp>
#include <atomic>
#include <fstream>
#include <memory>
#include <nlohmann/json.hpp>
#include <numeric>
#include <queue>
Expand Down Expand Up @@ -120,6 +121,9 @@ struct Args {
"Cannot load observable flips without a corresponding detection "
"event data file.");
}
if (num_threads == 0) {
throw std::invalid_argument("--threads must be at least 1.");
}
if (num_threads > 1000) {
throw std::invalid_argument(
"There is a maximum limit of 1000 threads imposed to avoid "
Expand Down Expand Up @@ -424,7 +428,8 @@ int main(int argc, char* argv[]) {
program.add_argument("--threads")
.help("Number of decoder threads to use")
.metavar("N")
.default_value(size_t(std::thread::hardware_concurrency()))
.default_value(size_t(
std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency()))
.store_into(args.num_threads);
program.add_argument("--beam")
.help("Beam to use for truncation (default = infinity)")
Expand Down Expand Up @@ -475,105 +480,71 @@ int main(int argc, char* argv[]) {
std::vector<stim::SparseShot> shots;
std::unique_ptr<stim::MeasureRecordWriter> writer;
args.extract(config, shots, writer);
std::atomic<size_t> next_unclaimed_shot;
std::vector<std::atomic<bool>> finished(shots.size());
std::vector<uint64_t> obs_predicted(shots.size());
std::vector<double> cost_predicted(shots.size());
std::vector<double> decoding_time_seconds(shots.size());
std::vector<std::atomic<bool>> low_confidence(shots.size());
std::vector<std::thread> decoder_threads;
const stim::DetectorErrorModel original_dem = config.dem.flattened();
std::vector<std::atomic<size_t>> error_use_totals(original_dem.count_errors());
std::vector<std::unique_ptr<TesseractDecoder>> decoders(args.num_threads);
std::vector<std::vector<size_t>> error_use_per_thread(
args.num_threads, std::vector<size_t>(original_dem.count_errors()));
bool has_obs = args.has_observables();
std::atomic<bool> worker_threads_please_terminate = false;
std::atomic<size_t> num_worker_threads_active;
for (size_t t = 0; t < args.num_threads; ++t) {
// After this value returns to 0, we know that no further shots will
// transition to finished.
++num_worker_threads_active;
decoder_threads.push_back(std::thread([&config, &next_unclaimed_shot, &shots, &obs_predicted,
&cost_predicted, &decoding_time_seconds, &low_confidence,
&finished, &error_use_totals, &has_obs,
&worker_threads_please_terminate,
&num_worker_threads_active, &original_dem]() {
TesseractDecoder decoder(config);
std::vector<size_t> error_use(original_dem.count_errors());
for (size_t shot;
!worker_threads_please_terminate and ((shot = next_unclaimed_shot++) < shots.size());) {
size_t num_errors = 0;
size_t num_low_confidence = 0;
double total_time_seconds = 0;
size_t num_observables = config.dem.count_observables();
size_t shot = parallel_for_shots_in_order(
shots.size(), args.num_threads,
[&](size_t thread_index, size_t shot_index) {
if (!decoders[thread_index]) {
decoders[thread_index] = std::make_unique<TesseractDecoder>(config);
}
auto& decoder = *decoders[thread_index];
auto& error_use = error_use_per_thread[thread_index];
auto start_time = std::chrono::high_resolution_clock::now();
decoder.decode_to_errors(shots[shot].hits);
decoder.decode_to_errors(shots[shot_index].hits);
auto stop_time = std::chrono::high_resolution_clock::now();
decoding_time_seconds[shot] =
decoding_time_seconds[shot_index] =
std::chrono::duration_cast<std::chrono::microseconds>(stop_time - start_time).count() /
1e6;
obs_predicted[shot] =
obs_predicted[shot_index] =
vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer));
low_confidence[shot] = decoder.low_confidence_flag;
cost_predicted[shot] = decoder.cost_from_errors(decoder.predicted_errors_buffer);
if (!has_obs or shots[shot].obs_mask_as_u64() == obs_predicted[shot]) {
// Only count the error uses for shots that did not have a logical
// error, if we know the obs flips.
low_confidence[shot_index] = decoder.low_confidence_flag;
cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer);
if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) {
for (size_t ei : decoder.predicted_errors_buffer) {
++error_use[ei];
}
}
finished[shot] = true;
}
// Add the error counts to the total
for (size_t ei = 0; ei < error_use_totals.size(); ++ei) {
error_use_totals[ei] += error_use[ei];
}
--num_worker_threads_active;
}));
}
size_t num_errors = 0;
size_t num_low_confidence = 0;
double total_time_seconds = 0;
size_t num_observables = config.dem.count_observables();
size_t shot = 0;
for (; shot < shots.size(); ++shot) {
while (num_worker_threads_active and !finished[shot]) {
// We break once the number of active worker threads is 0, at which point
// there will be no further changes to finished[shot].
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
// There can be no further changes to finished[shot]. If it is true, we
// process it and go to the next shot. If it is false, we break now as it
// will never be decoded and no subsequent shots will be decoded.
if (!finished[shot]) {
assert(num_worker_threads_active == 0);
// This and subsequent shots will never become decoded.
break;
}

if (writer) {
writer->write_bits((uint8_t*)&obs_predicted[shot], num_observables);
writer->write_end();
}

if (low_confidence[shot]) {
++num_low_confidence;
} else if (obs_predicted[shot] != shots[shot].obs_mask_as_u64()) {
++num_errors;
}

total_time_seconds += decoding_time_seconds[shot];

if (args.print_stats) {
std::cout << "num_shots = " << (shot + 1) << " num_low_confidence = " << num_low_confidence
<< " num_errors = " << num_errors << " total_time_seconds = " << total_time_seconds
<< std::endl;
std::cout << "cost = " << cost_predicted[shot] << std::endl;
std::cout.flush();
}
},
[&](size_t shot_index) {
if (writer) {
writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables);
writer->write_end();
}
if (low_confidence[shot_index]) {
++num_low_confidence;
} else if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) {
++num_errors;
}
total_time_seconds += decoding_time_seconds[shot_index];
if (args.print_stats) {
std::cout << "num_shots = " << (shot_index + 1)
<< " num_low_confidence = " << num_low_confidence
<< " num_errors = " << num_errors
<< " total_time_seconds = " << total_time_seconds << std::endl;
std::cout << "cost = " << cost_predicted[shot_index] << std::endl;
std::cout.flush();
}
return num_errors < args.max_errors;
});

if (num_errors >= args.max_errors) {
worker_threads_please_terminate = true;
std::vector<size_t> error_use_totals(original_dem.count_errors());
for (const auto& error_use : error_use_per_thread) {
for (size_t ei = 0; ei < error_use_totals.size(); ++ei) {
error_use_totals[ei] += error_use[ei];
}
}
for (size_t t = 0; t < args.num_threads; ++t) {
decoder_threads[t].join();
}

if (!args.dem_out_fname.empty()) {
std::vector<size_t> counts(error_use_totals.begin(), error_use_totals.end());
Expand Down
Loading
Loading