Skip to content

Commit fff3c75

Browse files
authored
Refactor bulk parallelism (#221)
The use of atomic counter and termination signaling is slightly non-obvious, and we use the same pattern in an ad-hoc way across both simplex and tesseract mains. The goal is to re-use this also in the python API, as an alternative to sinter. The API looks like this: ```c++ size_t parallel_for_shots_in_order(size_t num_shots, size_t num_threads, ProcessShot&& process_shot, ConsumeShot&& consume_shot) ``` here `process_shot` is called within a worker thread for each shot and `consume_shot` is called on the main thread for each shot (this should return a bool that signals termination when set to `false`). Example usage: ```c++ std::vector<std::unique_ptr<Decoder>> decoders(args.num_threads); std::vector<std::vector<size_t>> error_use_per_thread( args.num_threads, std::vector<size_t>(num_error_terms)); std::vector<Result> results(shots.size()); size_t num_consumed = parallel_for_shots_in_order( shots.size(), args.num_threads, // Process shot runs in parallel, potentially out of order. [&](size_t thread_index, size_t shot_index) { if (!decoders[thread_index]) { decoders[thread_index] = std::make_unique<Decoder>(config); } auto& decoder = *decoders[thread_index]; auto& error_use = error_use_per_thread[thread_index]; results[shot_index] = decoder.decode(shots[shot_index]); if (results[shot_index].count_for_stats) { for (size_t ei : decoder.predicted_errors_buffer) { ++error_use[ei]; } } }, // Consume shot runs on the caller thread, strictly in shot order: 0, 1, 2, ... [&](size_t shot_index) { emit_result(results[shot_index]); return !should_stop_early(results[shot_index]); }); // Optional: merge per-thread scratch after all workers have joined. std::vector<size_t> error_use_totals(num_error_terms); for (const auto& error_use : error_use_per_thread) { for (size_t ei = 0; ei < num_error_terms; ++ei) { error_use_totals[ei] += error_use[ei]; } } ```
1 parent 1b453aa commit fff3c75

File tree

3 files changed

+158
-157
lines changed

3 files changed

+158
-157
lines changed

src/simplex_main.cc

Lines changed: 47 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
#include <argparse/argparse.hpp>
1616
#include <atomic>
1717
#include <fstream>
18+
#include <memory>
1819
#include <nlohmann/json.hpp>
1920
#include <thread>
2021

@@ -107,6 +108,9 @@ struct Args {
107108
"Cannot load observable flips without a corresponding detection "
108109
"event data file.");
109110
}
111+
if (num_threads == 0) {
112+
throw std::invalid_argument("--threads must be at least 1.");
113+
}
110114
if (num_threads > 1000) {
111115
throw std::invalid_argument(
112116
"There is a maximum limit of 1000 threads imposed to avoid "
@@ -367,7 +371,8 @@ int main(int argc, char* argv[]) {
367371
program.add_argument("--threads")
368372
.help("Number of decoder threads to use")
369373
.metavar("N")
370-
.default_value(size_t(std::thread::hardware_concurrency()))
374+
.default_value(size_t(
375+
std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency()))
371376
.store_into(args.num_threads);
372377
program.add_argument("--parallelize-ilp")
373378
.help(
@@ -416,97 +421,64 @@ int main(int argc, char* argv[]) {
416421
std::vector<stim::SparseShot> shots;
417422
std::unique_ptr<stim::MeasureRecordWriter> writer;
418423
args.extract(config, shots, writer);
419-
std::atomic<size_t> next_unclaimed_shot;
420-
std::vector<std::atomic<bool>> finished(shots.size());
421424
std::vector<uint64_t> obs_predicted(shots.size());
422425
std::vector<double> cost_predicted(shots.size());
423426
std::vector<double> decoding_time_seconds(shots.size());
424-
std::vector<std::thread> decoder_threads;
425427
const stim::DetectorErrorModel original_dem = config.dem.flattened();
426-
std::vector<std::atomic<size_t>> error_use_totals(original_dem.count_errors());
428+
std::vector<std::unique_ptr<SimplexDecoder>> decoders(args.num_threads);
429+
std::vector<std::vector<size_t>> error_use_per_thread(
430+
args.num_threads, std::vector<size_t>(original_dem.count_errors()));
427431
bool has_obs = args.has_observables();
428-
std::atomic<bool> worker_threads_please_terminate = false;
429-
std::atomic<size_t> num_worker_threads_active;
430-
for (size_t t = 0; t < args.num_threads; ++t) {
431-
// After this value returns to 0, we know that no further shots will
432-
// transition to finished.
433-
++num_worker_threads_active;
434-
decoder_threads.push_back(std::thread([&config, &next_unclaimed_shot, &shots, &obs_predicted,
435-
&cost_predicted, &decoding_time_seconds, &finished,
436-
&error_use_totals, &has_obs,
437-
&worker_threads_please_terminate,
438-
&num_worker_threads_active, &original_dem]() {
439-
SimplexDecoder decoder(config);
440-
std::vector<size_t> error_use(original_dem.count_errors());
441-
for (size_t shot;
442-
!worker_threads_please_terminate and ((shot = next_unclaimed_shot++) < shots.size());) {
432+
size_t num_errors = 0;
433+
double total_time_seconds = 0;
434+
size_t num_observables = config.dem.count_observables();
435+
size_t shot = parallel_for_shots_in_order(
436+
shots.size(), args.num_threads,
437+
[&](size_t thread_index, size_t shot_index) {
438+
if (!decoders[thread_index]) {
439+
decoders[thread_index] = std::make_unique<SimplexDecoder>(config);
440+
}
441+
auto& decoder = *decoders[thread_index];
442+
auto& error_use = error_use_per_thread[thread_index];
443443
auto start_time = std::chrono::high_resolution_clock::now();
444-
decoder.decode_to_errors(shots[shot].hits);
444+
decoder.decode_to_errors(shots[shot_index].hits);
445445
auto stop_time = std::chrono::high_resolution_clock::now();
446-
decoding_time_seconds[shot] =
446+
decoding_time_seconds[shot_index] =
447447
std::chrono::duration_cast<std::chrono::microseconds>(stop_time - start_time).count() /
448448
1e6;
449-
obs_predicted[shot] =
449+
obs_predicted[shot_index] =
450450
vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer));
451-
cost_predicted[shot] = decoder.cost_from_errors(decoder.predicted_errors_buffer);
452-
if (!has_obs or shots[shot].obs_mask_as_u64() == obs_predicted[shot]) {
453-
// Only count the error uses for shots that did not have a logical
454-
// error, if we know the obs flips.
451+
cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer);
452+
if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) {
455453
for (size_t ei : decoder.predicted_errors_buffer) {
456454
++error_use[ei];
457455
}
458456
}
459-
finished[shot] = true;
460-
}
461-
// Add the error counts to the total
462-
for (size_t ei = 0; ei < config.dem.count_errors(); ++ei) {
463-
error_use_totals[ei] += error_use[ei];
464-
}
465-
--num_worker_threads_active;
466-
}));
467-
}
468-
size_t num_errors = 0;
469-
double total_time_seconds = 0;
470-
size_t num_observables = config.dem.count_observables();
471-
size_t shot = 0;
472-
for (; shot < shots.size(); ++shot) {
473-
while (num_worker_threads_active and !finished[shot]) {
474-
// We break once the number of active worker threads is 0, at which point
475-
// there will be no further changes to finished[shot].
476-
std::this_thread::sleep_for(std::chrono::milliseconds(100));
477-
}
478-
// There can be no further changes to finished[shot]. If it is true, we
479-
// process it and go to the next shot. If it is false, we break now as it
480-
// will never be decoded and no subsequent shots will be decoded.
481-
if (!finished[shot]) {
482-
assert(num_worker_threads_active == 0);
483-
// This and subsequent shots will never become decoded.
484-
break;
485-
}
486-
487-
if (writer) {
488-
writer->write_bits((uint8_t*)&obs_predicted[shot], num_observables);
489-
writer->write_end();
490-
}
491-
492-
if (obs_predicted[shot] != shots[shot].obs_mask_as_u64()) ++num_errors;
493-
494-
total_time_seconds += decoding_time_seconds[shot];
495-
496-
if (args.print_stats) {
497-
std::cout << "num_shots = " << (shot + 1) << " num_errors = " << num_errors
498-
<< " total_time_seconds = " << total_time_seconds << std::endl;
499-
std::cout << "cost = " << cost_predicted[shot] << std::endl;
500-
std::cout.flush();
501-
}
457+
},
458+
[&](size_t shot_index) {
459+
if (writer) {
460+
writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables);
461+
writer->write_end();
462+
}
463+
if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) {
464+
++num_errors;
465+
}
466+
total_time_seconds += decoding_time_seconds[shot_index];
467+
if (args.print_stats) {
468+
std::cout << "num_shots = " << (shot_index + 1) << " num_errors = " << num_errors
469+
<< " total_time_seconds = " << total_time_seconds << std::endl;
470+
std::cout << "cost = " << cost_predicted[shot_index] << std::endl;
471+
std::cout.flush();
472+
}
473+
return num_errors < args.max_errors;
474+
});
502475

503-
if (num_errors >= args.max_errors) {
504-
worker_threads_please_terminate = true;
476+
std::vector<size_t> error_use_totals(original_dem.count_errors());
477+
for (const auto& error_use : error_use_per_thread) {
478+
for (size_t ei = 0; ei < error_use_totals.size(); ++ei) {
479+
error_use_totals[ei] += error_use[ei];
505480
}
506481
}
507-
for (size_t t = 0; t < args.num_threads; ++t) {
508-
decoder_threads[t].join();
509-
}
510482

511483
if (!args.dem_out_fname.empty()) {
512484
std::vector<size_t> counts(error_use_totals.begin(), error_use_totals.end());

src/tesseract_main.cc

Lines changed: 53 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
#include <argparse/argparse.hpp>
1717
#include <atomic>
1818
#include <fstream>
19+
#include <memory>
1920
#include <nlohmann/json.hpp>
2021
#include <numeric>
2122
#include <queue>
@@ -120,6 +121,9 @@ struct Args {
120121
"Cannot load observable flips without a corresponding detection "
121122
"event data file.");
122123
}
124+
if (num_threads == 0) {
125+
throw std::invalid_argument("--threads must be at least 1.");
126+
}
123127
if (num_threads > 1000) {
124128
throw std::invalid_argument(
125129
"There is a maximum limit of 1000 threads imposed to avoid "
@@ -424,7 +428,8 @@ int main(int argc, char* argv[]) {
424428
program.add_argument("--threads")
425429
.help("Number of decoder threads to use")
426430
.metavar("N")
427-
.default_value(size_t(std::thread::hardware_concurrency()))
431+
.default_value(size_t(
432+
std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency()))
428433
.store_into(args.num_threads);
429434
program.add_argument("--beam")
430435
.help("Beam to use for truncation (default = infinity)")
@@ -475,105 +480,71 @@ int main(int argc, char* argv[]) {
475480
std::vector<stim::SparseShot> shots;
476481
std::unique_ptr<stim::MeasureRecordWriter> writer;
477482
args.extract(config, shots, writer);
478-
std::atomic<size_t> next_unclaimed_shot;
479-
std::vector<std::atomic<bool>> finished(shots.size());
480483
std::vector<uint64_t> obs_predicted(shots.size());
481484
std::vector<double> cost_predicted(shots.size());
482485
std::vector<double> decoding_time_seconds(shots.size());
483486
std::vector<std::atomic<bool>> low_confidence(shots.size());
484-
std::vector<std::thread> decoder_threads;
485487
const stim::DetectorErrorModel original_dem = config.dem.flattened();
486-
std::vector<std::atomic<size_t>> error_use_totals(original_dem.count_errors());
488+
std::vector<std::unique_ptr<TesseractDecoder>> decoders(args.num_threads);
489+
std::vector<std::vector<size_t>> error_use_per_thread(
490+
args.num_threads, std::vector<size_t>(original_dem.count_errors()));
487491
bool has_obs = args.has_observables();
488-
std::atomic<bool> worker_threads_please_terminate = false;
489-
std::atomic<size_t> num_worker_threads_active;
490-
for (size_t t = 0; t < args.num_threads; ++t) {
491-
// After this value returns to 0, we know that no further shots will
492-
// transition to finished.
493-
++num_worker_threads_active;
494-
decoder_threads.push_back(std::thread([&config, &next_unclaimed_shot, &shots, &obs_predicted,
495-
&cost_predicted, &decoding_time_seconds, &low_confidence,
496-
&finished, &error_use_totals, &has_obs,
497-
&worker_threads_please_terminate,
498-
&num_worker_threads_active, &original_dem]() {
499-
TesseractDecoder decoder(config);
500-
std::vector<size_t> error_use(original_dem.count_errors());
501-
for (size_t shot;
502-
!worker_threads_please_terminate and ((shot = next_unclaimed_shot++) < shots.size());) {
492+
size_t num_errors = 0;
493+
size_t num_low_confidence = 0;
494+
double total_time_seconds = 0;
495+
size_t num_observables = config.dem.count_observables();
496+
size_t shot = parallel_for_shots_in_order(
497+
shots.size(), args.num_threads,
498+
[&](size_t thread_index, size_t shot_index) {
499+
if (!decoders[thread_index]) {
500+
decoders[thread_index] = std::make_unique<TesseractDecoder>(config);
501+
}
502+
auto& decoder = *decoders[thread_index];
503+
auto& error_use = error_use_per_thread[thread_index];
503504
auto start_time = std::chrono::high_resolution_clock::now();
504-
decoder.decode_to_errors(shots[shot].hits);
505+
decoder.decode_to_errors(shots[shot_index].hits);
505506
auto stop_time = std::chrono::high_resolution_clock::now();
506-
decoding_time_seconds[shot] =
507+
decoding_time_seconds[shot_index] =
507508
std::chrono::duration_cast<std::chrono::microseconds>(stop_time - start_time).count() /
508509
1e6;
509-
obs_predicted[shot] =
510+
obs_predicted[shot_index] =
510511
vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer));
511-
low_confidence[shot] = decoder.low_confidence_flag;
512-
cost_predicted[shot] = decoder.cost_from_errors(decoder.predicted_errors_buffer);
513-
if (!has_obs or shots[shot].obs_mask_as_u64() == obs_predicted[shot]) {
514-
// Only count the error uses for shots that did not have a logical
515-
// error, if we know the obs flips.
512+
low_confidence[shot_index] = decoder.low_confidence_flag;
513+
cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer);
514+
if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) {
516515
for (size_t ei : decoder.predicted_errors_buffer) {
517516
++error_use[ei];
518517
}
519518
}
520-
finished[shot] = true;
521-
}
522-
// Add the error counts to the total
523-
for (size_t ei = 0; ei < error_use_totals.size(); ++ei) {
524-
error_use_totals[ei] += error_use[ei];
525-
}
526-
--num_worker_threads_active;
527-
}));
528-
}
529-
size_t num_errors = 0;
530-
size_t num_low_confidence = 0;
531-
double total_time_seconds = 0;
532-
size_t num_observables = config.dem.count_observables();
533-
size_t shot = 0;
534-
for (; shot < shots.size(); ++shot) {
535-
while (num_worker_threads_active and !finished[shot]) {
536-
// We break once the number of active worker threads is 0, at which point
537-
// there will be no further changes to finished[shot].
538-
std::this_thread::sleep_for(std::chrono::milliseconds(100));
539-
}
540-
// There can be no further changes to finished[shot]. If it is true, we
541-
// process it and go to the next shot. If it is false, we break now as it
542-
// will never be decoded and no subsequent shots will be decoded.
543-
if (!finished[shot]) {
544-
assert(num_worker_threads_active == 0);
545-
// This and subsequent shots will never become decoded.
546-
break;
547-
}
548-
549-
if (writer) {
550-
writer->write_bits((uint8_t*)&obs_predicted[shot], num_observables);
551-
writer->write_end();
552-
}
553-
554-
if (low_confidence[shot]) {
555-
++num_low_confidence;
556-
} else if (obs_predicted[shot] != shots[shot].obs_mask_as_u64()) {
557-
++num_errors;
558-
}
559-
560-
total_time_seconds += decoding_time_seconds[shot];
561-
562-
if (args.print_stats) {
563-
std::cout << "num_shots = " << (shot + 1) << " num_low_confidence = " << num_low_confidence
564-
<< " num_errors = " << num_errors << " total_time_seconds = " << total_time_seconds
565-
<< std::endl;
566-
std::cout << "cost = " << cost_predicted[shot] << std::endl;
567-
std::cout.flush();
568-
}
519+
},
520+
[&](size_t shot_index) {
521+
if (writer) {
522+
writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables);
523+
writer->write_end();
524+
}
525+
if (low_confidence[shot_index]) {
526+
++num_low_confidence;
527+
} else if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) {
528+
++num_errors;
529+
}
530+
total_time_seconds += decoding_time_seconds[shot_index];
531+
if (args.print_stats) {
532+
std::cout << "num_shots = " << (shot_index + 1)
533+
<< " num_low_confidence = " << num_low_confidence
534+
<< " num_errors = " << num_errors
535+
<< " total_time_seconds = " << total_time_seconds << std::endl;
536+
std::cout << "cost = " << cost_predicted[shot_index] << std::endl;
537+
std::cout.flush();
538+
}
539+
return num_errors < args.max_errors;
540+
});
569541

570-
if (num_errors >= args.max_errors) {
571-
worker_threads_please_terminate = true;
542+
std::vector<size_t> error_use_totals(original_dem.count_errors());
543+
for (const auto& error_use : error_use_per_thread) {
544+
for (size_t ei = 0; ei < error_use_totals.size(); ++ei) {
545+
error_use_totals[ei] += error_use[ei];
572546
}
573547
}
574-
for (size_t t = 0; t < args.num_threads; ++t) {
575-
decoder_threads[t].join();
576-
}
577548

578549
if (!args.dem_out_fname.empty()) {
579550
std::vector<size_t> counts(error_use_totals.begin(), error_use_totals.end());

0 commit comments

Comments (0)