From 86382592b42451d2865ec7f82699af44e4e7788d Mon Sep 17 00:00:00 2001 From: noajshu Date: Wed, 18 Mar 2026 20:58:07 +0000 Subject: [PATCH 01/10] ignore module.bazel.lock --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index dbfe316..65d92e3 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,5 @@ user.bazelrc # Ignore python extension module produced by CMake. src/tesseract_decoder*.so + +MODULE.bazel.lock From 3764f82a5eaa2574fb5d0f35f221746c3497c2aa Mon Sep 17 00:00:00 2001 From: noajshu Date: Wed, 18 Mar 2026 20:58:29 +0000 Subject: [PATCH 02/10] add ignore_extra_checks=False kwarg to the decomposition helpers --- src/py/_tesseract_py_util/decompose_errors.py | 51 +++++++++++++++---- .../decompose_errors_test.py | 31 +++++++++++ src/py/_tesseract_py_util/demutil.py | 12 +++-- src/py/_tesseract_py_util/demutil_test.py | 25 +++++++++ 4 files changed, 106 insertions(+), 13 deletions(-) diff --git a/src/py/_tesseract_py_util/decompose_errors.py b/src/py/_tesseract_py_util/decompose_errors.py index d20fc6a..6cf9858 100644 --- a/src/py/_tesseract_py_util/decompose_errors.py +++ b/src/py/_tesseract_py_util/decompose_errors.py @@ -89,7 +89,9 @@ def get_component_obs_matching_undecomposed_obs( def decompose_errors_using_detector_assignment( - dem: stim.DetectorErrorModel, detector_component_func: Callable[[int], int] + dem: stim.DetectorErrorModel, + detector_component_func: Callable[[int], int], + disable_extra_checks: bool = False, ) -> stim.DetectorErrorModel: """Decomposes errors in the detector error model `dem` based on an assignment of detectors to components by the function `detector_component_func`. @@ -112,6 +114,10 @@ def decompose_errors_using_detector_assignment( detector_component_func : Callable[[int], int] A function that maps a detector id to its component. i.e. This could map a detector index to 0 if it is X-type or to 1 if it is Z-type. 
+ disable_extra_checks : bool + If True, decomposition will proceed even if a component of an error is not + present as its own error in the dem. In this case, the component is + assumed to have no observables. Returns ------- @@ -155,11 +161,14 @@ def decompose_errors_using_detector_assignment( sorted(d for d in detectors if det_components[d] == c) ) if component_dets not in single_component_dets_to_obs: - raise ValueError( - f"The dem error `{instruction}` needs to be decomposed into components, however " - f"the component with detectors {component_dets} is not present as its own error " - "in the dem." - ) + if disable_extra_checks: + single_component_dets_to_obs[component_dets].add(tuple()) + else: + raise ValueError( + f"The dem error `{instruction}` needs to be decomposed into components, however " + f"the component with detectors {component_dets} is not present as its own error " + "in the dem." + ) dets_by_component.append(component_dets) obs_options_by_component.append( single_component_dets_to_obs[component_dets] @@ -202,7 +211,9 @@ def decompose_errors_using_detector_assignment( def decompose_errors_using_detector_coordinate_assignment( - dem: stim.DetectorErrorModel, coord_to_component_func: Callable[[list[float]], int] + dem: stim.DetectorErrorModel, + coord_to_component_func: Callable[[list[float]], int], + disable_extra_checks: bool = False, ) -> stim.DetectorErrorModel: """Decomposes errors in the detector error model `dem` based on an assignment of detectors to components using a function of the detector coordinates. @@ -225,6 +236,10 @@ def decompose_errors_using_detector_coordinate_assignment( A function that coordinates of a detector to an integer corresponding to the index of a component, to be used for the decomposition. The coordinates are provided as a list of floats. + disable_extra_checks : bool + If True, decomposition will proceed even if a component of an error is not + present as its own error in the dem. 
In this case, the component is + assumed to have no observables. Returns ------- @@ -237,7 +252,9 @@ def component_using_coords(detector_id: int) -> int: return coord_to_component_func(detector_coords[detector_id]) return decompose_errors_using_detector_assignment( - dem=dem, detector_component_func=component_using_coords + dem=dem, + detector_component_func=component_using_coords, + disable_extra_checks=disable_extra_checks, ) @@ -252,6 +269,7 @@ def detector_coord_to_basis_for_stim_surface_code_convention(coord: tuple[int]) def decompose_errors_using_last_coordinate_index( dem: stim.DetectorErrorModel, + disable_extra_checks: bool = False, ) -> stim.DetectorErrorModel: """Decomposes errors in the detector error model `dem` based on an assignment of detectors to components by the last element of each detector coordinate. @@ -269,6 +287,10 @@ def decompose_errors_using_last_coordinate_index( ---------- dem : stim.DetectorErrorModel The detector error model to decompose. + disable_extra_checks : bool + If True, decomposition will proceed even if a component of an error is not + present as its own error in the dem. In this case, the component is + assumed to have no observables. 
Returns ------- @@ -281,12 +303,15 @@ def last_coordinate_component(detector_id: int) -> int: return detector_coords[detector_id][-1] return decompose_errors_using_detector_assignment( - dem=dem, detector_component_func=last_coordinate_component + dem=dem, + detector_component_func=last_coordinate_component, + disable_extra_checks=disable_extra_checks, ) def decompose_errors_for_stim_surface_code_coords( dem: stim.DetectorErrorModel, + disable_extra_checks: bool = False, ) -> stim.DetectorErrorModel: """Decomposes the errors in the dem, such that each component of a decomposed error only triggers detectors of one basis (X or Z) @@ -302,6 +327,10 @@ def decompose_errors_for_stim_surface_code_coords( ---------- dem : stim.DetectorErrorModel The detector error model to decompose + disable_extra_checks : bool + If True, decomposition will proceed even if a component of an error is not + present as its own error in the dem. In this case, the component is + assumed to have no observables. Returns ------- @@ -316,7 +345,9 @@ def stim_surface_code_det_component(detector_id: int) -> int: ) return decompose_errors_using_detector_assignment( - dem=dem, detector_component_func=stim_surface_code_det_component + dem=dem, + detector_component_func=stim_surface_code_det_component, + disable_extra_checks=disable_extra_checks, ) diff --git a/src/py/_tesseract_py_util/decompose_errors_test.py b/src/py/_tesseract_py_util/decompose_errors_test.py index 45ac789..983ef2b 100644 --- a/src/py/_tesseract_py_util/decompose_errors_test.py +++ b/src/py/_tesseract_py_util/decompose_errors_test.py @@ -185,6 +185,37 @@ def test_undecompose_errors_surface_code(): assert dem_decomposed_using_coords_func == dem_decomposed_using_coords +def test_decompose_errors_disable_extra_checks(): + dem = stim.DetectorErrorModel(""" +detector(0) D0 +detector(1) D1 +# Error with multiple components (D0 and D1) +error(0.1) D0 D1 +# D0 exists as a standalone error +error(0.1) D0 +# D1 DOES NOT exist as a standalone 
error +""") + + # Should fail by default + with pytest.raises(ValueError, match="needs to be decomposed into components"): + decompose_errors_using_last_coordinate_index(dem) + + # Should pass with disable_extra_checks=True + decomposed_dem = decompose_errors_using_last_coordinate_index(dem, disable_extra_checks=True) + + # Check that D0 D1 was decomposed. + # Since D1 doesn't exist, it should be treated as having no observables. + # D0 exists as error(0.1) D0, so it has no observables either. + # So D0 D1 should decompose to D0 ^ D1. + expected_dem = stim.DetectorErrorModel(""" +detector(0) D0 +detector(1) D1 +error(0.1) D0 ^ D1 +error(0.1) D0 +""") + assert str(decomposed_dem) == str(expected_dem) + + def test_undecompose_errors_with_repeat_block(): dem = stim.DetectorErrorModel("""error(0.1) D2 D5 ^ D10 L1 repeat 10 { diff --git a/src/py/_tesseract_py_util/demutil.py b/src/py/_tesseract_py_util/demutil.py index 6596340..29a5839 100644 --- a/src/py/_tesseract_py_util/demutil.py +++ b/src/py/_tesseract_py_util/demutil.py @@ -23,13 +23,19 @@ def decompose_errors( - dem: stim.DetectorErrorModel, method: str = "stim-surfacecode-coords" + dem: stim.DetectorErrorModel, + method: str = "stim-surfacecode-coords", + disable_extra_checks: bool = False, ) -> stim.DetectorErrorModel: """Dispatch decomposition strategy by method name.""" if method == "stim-surfacecode-coords": - return decompose_errors_for_stim_surface_code_coords(dem) + return decompose_errors_for_stim_surface_code_coords( + dem, disable_extra_checks=disable_extra_checks + ) if method == "last-coordinate-index": - return decompose_errors_using_last_coordinate_index(dem) + return decompose_errors_using_last_coordinate_index( + dem, disable_extra_checks=disable_extra_checks + ) raise ValueError( "Unknown decomposition method " f"{method!r}. Expected 'stim-surfacecode-coords' or 'last-coordinate-index'." 
diff --git a/src/py/_tesseract_py_util/demutil_test.py b/src/py/_tesseract_py_util/demutil_test.py index 916f776..a4847d1 100644 --- a/src/py/_tesseract_py_util/demutil_test.py +++ b/src/py/_tesseract_py_util/demutil_test.py @@ -67,5 +67,30 @@ def test_regeneralize_spatial_dem_averages_template_probabilities(): assert probs == pytest.approx([0.2, 0.3]) +def test_decompose_errors_top_level_disable_extra_checks(): + dem = stim.DetectorErrorModel(""" +detector(0) D0 +detector(1) D1 +# Error with multiple components (D0 and D1) +error(0.1) D0 D1 +# D0 exists as a standalone error +error(0.1) D0 +# D1 DOES NOT exist as a standalone error +""") + + # Should pass with disable_extra_checks=True + decomposed_dem = demutil.decompose_errors( + dem, method="last-coordinate-index", disable_extra_checks=True + ) + + expected_dem = stim.DetectorErrorModel(""" +detector(0) D0 +detector(1) D1 +error(0.1) D0 ^ D1 +error(0.1) D0 +""") + assert str(decomposed_dem) == str(expected_dem) + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__])) From 7b1325d43a416e6723812198123284904470a313 Mon Sep 17 00:00:00 2001 From: noajshu Date: Wed, 18 Mar 2026 21:47:06 +0000 Subject: [PATCH 03/10] replace the skip functionality with omitting these errors --- src/py/_tesseract_py_util/decompose_errors.py | 51 ++++++++++--------- .../decompose_errors_test.py | 11 ++-- src/py/_tesseract_py_util/demutil.py | 6 +-- src/py/_tesseract_py_util/demutil_test.py | 7 ++- 4 files changed, 35 insertions(+), 40 deletions(-) diff --git a/src/py/_tesseract_py_util/decompose_errors.py b/src/py/_tesseract_py_util/decompose_errors.py index 6cf9858..8ed3302 100644 --- a/src/py/_tesseract_py_util/decompose_errors.py +++ b/src/py/_tesseract_py_util/decompose_errors.py @@ -91,7 +91,7 @@ def get_component_obs_matching_undecomposed_obs( def decompose_errors_using_detector_assignment( dem: stim.DetectorErrorModel, detector_component_func: Callable[[int], int], - disable_extra_checks: bool = False, + 
strip_undecomposable_errors: bool = False, ) -> stim.DetectorErrorModel: """Decomposes errors in the detector error model `dem` based on an assignment of detectors to components by the function `detector_component_func`. @@ -114,10 +114,9 @@ def decompose_errors_using_detector_assignment( detector_component_func : Callable[[int], int] A function that maps a detector id to its component. i.e. This could map a detector index to 0 if it is X-type or to 1 if it is Z-type. - disable_extra_checks : bool - If True, decomposition will proceed even if a component of an error is not - present as its own error in the dem. In this case, the component is - assumed to have no observables. + strip_undecomposable_errors : bool + If True, errors that cannot be decomposed due to a missing component error + will be stripped from the output DEM instead of raising a ValueError. Returns ------- @@ -156,13 +155,15 @@ def decompose_errors_using_detector_assignment( dets_by_component = [] obs_options_by_component = [] + is_undecomposable = False for c in unique_components: component_dets = tuple( sorted(d for d in detectors if det_components[d] == c) ) if component_dets not in single_component_dets_to_obs: - if disable_extra_checks: - single_component_dets_to_obs[component_dets].add(tuple()) + if strip_undecomposable_errors: + is_undecomposable = True + break else: raise ValueError( f"The dem error `{instruction}` needs to be decomposed into components, however " @@ -174,6 +175,9 @@ def decompose_errors_using_detector_assignment( single_component_dets_to_obs[component_dets] ) + if is_undecomposable: + continue + # Assign observables to each component, such that they are consistent with the # observables of the undecomposed error consistent_obs_by_component = get_component_obs_matching_undecomposed_obs( @@ -213,7 +217,7 @@ def decompose_errors_using_detector_assignment( def decompose_errors_using_detector_coordinate_assignment( dem: stim.DetectorErrorModel, coord_to_component_func: 
Callable[[list[float]], int], - disable_extra_checks: bool = False, + strip_undecomposable_errors: bool = False, ) -> stim.DetectorErrorModel: """Decomposes errors in the detector error model `dem` based on an assignment of detectors to components using a function of the detector coordinates. @@ -236,10 +240,9 @@ def decompose_errors_using_detector_coordinate_assignment( A function that coordinates of a detector to an integer corresponding to the index of a component, to be used for the decomposition. The coordinates are provided as a list of floats. - disable_extra_checks : bool - If True, decomposition will proceed even if a component of an error is not - present as its own error in the dem. In this case, the component is - assumed to have no observables. + strip_undecomposable_errors : bool + If True, errors that cannot be decomposed due to a missing component error + will be stripped from the output DEM instead of raising a ValueError. Returns ------- @@ -254,7 +257,7 @@ def component_using_coords(detector_id: int) -> int: return decompose_errors_using_detector_assignment( dem=dem, detector_component_func=component_using_coords, - disable_extra_checks=disable_extra_checks, + strip_undecomposable_errors=strip_undecomposable_errors, ) @@ -269,7 +272,7 @@ def detector_coord_to_basis_for_stim_surface_code_convention(coord: tuple[int]) def decompose_errors_using_last_coordinate_index( dem: stim.DetectorErrorModel, - disable_extra_checks: bool = False, + strip_undecomposable_errors: bool = False, ) -> stim.DetectorErrorModel: """Decomposes errors in the detector error model `dem` based on an assignment of detectors to components by the last element of each detector coordinate. @@ -287,10 +290,9 @@ def decompose_errors_using_last_coordinate_index( ---------- dem : stim.DetectorErrorModel The detector error model to decompose. 
- disable_extra_checks : bool - If True, decomposition will proceed even if a component of an error is not - present as its own error in the dem. In this case, the component is - assumed to have no observables. + strip_undecomposable_errors : bool + If True, errors that cannot be decomposed due to a missing component error + will be stripped from the output DEM instead of raising a ValueError. Returns ------- @@ -305,13 +307,13 @@ def last_coordinate_component(detector_id: int) -> int: return decompose_errors_using_detector_assignment( dem=dem, detector_component_func=last_coordinate_component, - disable_extra_checks=disable_extra_checks, + strip_undecomposable_errors=strip_undecomposable_errors, ) def decompose_errors_for_stim_surface_code_coords( dem: stim.DetectorErrorModel, - disable_extra_checks: bool = False, + strip_undecomposable_errors: bool = False, ) -> stim.DetectorErrorModel: """Decomposes the errors in the dem, such that each component of a decomposed error only triggers detectors of one basis (X or Z) @@ -327,10 +329,9 @@ def decompose_errors_for_stim_surface_code_coords( ---------- dem : stim.DetectorErrorModel The detector error model to decompose - disable_extra_checks : bool - If True, decomposition will proceed even if a component of an error is not - present as its own error in the dem. In this case, the component is - assumed to have no observables. + strip_undecomposable_errors : bool + If True, errors that cannot be decomposed due to a missing component error + will be stripped from the output DEM instead of raising a ValueError. 
Returns ------- @@ -347,7 +348,7 @@ def stim_surface_code_det_component(detector_id: int) -> int: return decompose_errors_using_detector_assignment( dem=dem, detector_component_func=stim_surface_code_det_component, - disable_extra_checks=disable_extra_checks, + strip_undecomposable_errors=strip_undecomposable_errors, ) diff --git a/src/py/_tesseract_py_util/decompose_errors_test.py b/src/py/_tesseract_py_util/decompose_errors_test.py index 983ef2b..873869a 100644 --- a/src/py/_tesseract_py_util/decompose_errors_test.py +++ b/src/py/_tesseract_py_util/decompose_errors_test.py @@ -185,7 +185,7 @@ def test_undecompose_errors_surface_code(): assert dem_decomposed_using_coords_func == dem_decomposed_using_coords -def test_decompose_errors_disable_extra_checks(): +def test_decompose_errors_strip_undecomposable_errors(): dem = stim.DetectorErrorModel(""" detector(0) D0 detector(1) D1 @@ -200,17 +200,12 @@ def test_decompose_errors_disable_extra_checks(): with pytest.raises(ValueError, match="needs to be decomposed into components"): decompose_errors_using_last_coordinate_index(dem) - # Should pass with disable_extra_checks=True - decomposed_dem = decompose_errors_using_last_coordinate_index(dem, disable_extra_checks=True) + # Should pass with strip_undecomposable_errors=True, but D0 D1 error is removed + decomposed_dem = decompose_errors_using_last_coordinate_index(dem, strip_undecomposable_errors=True) - # Check that D0 D1 was decomposed. - # Since D1 doesn't exist, it should be treated as having no observables. - # D0 exists as error(0.1) D0, so it has no observables either. - # So D0 D1 should decompose to D0 ^ D1. 
expected_dem = stim.DetectorErrorModel(""" detector(0) D0 detector(1) D1 -error(0.1) D0 ^ D1 error(0.1) D0 """) assert str(decomposed_dem) == str(expected_dem) diff --git a/src/py/_tesseract_py_util/demutil.py b/src/py/_tesseract_py_util/demutil.py index 29a5839..cc9aeee 100644 --- a/src/py/_tesseract_py_util/demutil.py +++ b/src/py/_tesseract_py_util/demutil.py @@ -25,16 +25,16 @@ def decompose_errors( dem: stim.DetectorErrorModel, method: str = "stim-surfacecode-coords", - disable_extra_checks: bool = False, + strip_undecomposable_errors: bool = False, ) -> stim.DetectorErrorModel: """Dispatch decomposition strategy by method name.""" if method == "stim-surfacecode-coords": return decompose_errors_for_stim_surface_code_coords( - dem, disable_extra_checks=disable_extra_checks + dem, strip_undecomposable_errors=strip_undecomposable_errors ) if method == "last-coordinate-index": return decompose_errors_using_last_coordinate_index( - dem, disable_extra_checks=disable_extra_checks + dem, strip_undecomposable_errors=strip_undecomposable_errors ) raise ValueError( "Unknown decomposition method " diff --git a/src/py/_tesseract_py_util/demutil_test.py b/src/py/_tesseract_py_util/demutil_test.py index a4847d1..7aee889 100644 --- a/src/py/_tesseract_py_util/demutil_test.py +++ b/src/py/_tesseract_py_util/demutil_test.py @@ -67,7 +67,7 @@ def test_regeneralize_spatial_dem_averages_template_probabilities(): assert probs == pytest.approx([0.2, 0.3]) -def test_decompose_errors_top_level_disable_extra_checks(): +def test_decompose_errors_top_level_strip_undecomposable_errors(): dem = stim.DetectorErrorModel(""" detector(0) D0 detector(1) D1 @@ -78,15 +78,14 @@ def test_decompose_errors_top_level_disable_extra_checks(): # D1 DOES NOT exist as a standalone error """) - # Should pass with disable_extra_checks=True + # Should pass with strip_undecomposable_errors=True decomposed_dem = demutil.decompose_errors( - dem, method="last-coordinate-index", disable_extra_checks=True + dem, 
method="last-coordinate-index", strip_undecomposable_errors=True ) expected_dem = stim.DetectorErrorModel(""" detector(0) D0 detector(1) D1 -error(0.1) D0 ^ D1 error(0.1) D0 """) assert str(decomposed_dem) == str(expected_dem) From b2fb2bd4c0edf0089cba0e8b312961d0b1ec1b2b Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Fri, 20 Mar 2026 17:21:45 -0700 Subject: [PATCH 04/10] Update Python README for strip_undecomposable_errors flag --- src/py/README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/py/README.md b/src/py/README.md index bbe44d5..b3e2aea 100644 --- a/src/py/README.md +++ b/src/py/README.md @@ -574,11 +574,12 @@ print(f"Logical error rate: {result.errors / result.shots}") The `tesseract_decoder.demutil` module provides utilities for manipulating `stim.DetectorErrorModel` objects, specifically for decomposing complex error mechanisms into simpler components and regeneralizing spatial error models. #### Functions -* `demutil.decompose_errors(dem: stim.DetectorErrorModel, method: str) -> stim.DetectorErrorModel` +* `demutil.decompose_errors(dem: stim.DetectorErrorModel, method: str, strip_undecomposable_errors: bool = False) -> stim.DetectorErrorModel` * Decomposes error mechanisms in a DEM into simpler components based on the specified method. * Supported methods: * `"stim-surfacecode-coords"`: Decomposes errors based on the spatial coordinates of detectors, assuming a surface code layout where coordinates indicate X or Z basis. * `"last-coordinate-index"`: Decomposes errors using the last coordinate of the detector as the component identifier. + * `strip_undecomposable_errors`: If `False` (default), raises an error when a complex error cannot be decomposed into known atomic component errors. If `True`, silently drops undecomposable complex errors and continues. 
* **Note:** For decomposition to work, the DEM must contain "atomic" errors (errors involving only one component) that explain the components of the complex errors. **Example Usage**: @@ -602,6 +603,13 @@ nice_matchable_dem = demutil.decompose_errors(dem, method='stim-surfacecode-coor # Re-decompose the errors assuming the last-coordinate index indicates the component: nice_matchable_dem2 = demutil.decompose_errors(dem, method='last-coordinate-index') + +# Optionally drop undecomposable complex errors instead of raising. +nice_matchable_dem3 = demutil.decompose_errors( + dem, + method='last-coordinate-index', + strip_undecomposable_errors=True, +) ``` * `demutil.regeneralize_spatial_dem(templates: list[stim.DetectorErrorModel], scaffold: stim.DetectorErrorModel, verbose: bool = False) -> stim.DetectorErrorModel` From 37605ef042127799c2393b77d8c99065dfcd24d9 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sat, 21 Mar 2026 22:25:30 +0200 Subject: [PATCH 05/10] Clarify early-stop behavior in shot-parallel helper --- src/simplex_main.cc | 126 ++++++++++++++----------------------- src/tesseract_main.cc | 143 ++++++++++++++++-------------------------- src/utils.h | 56 +++++++++++++++++ 3 files changed, 158 insertions(+), 167 deletions(-) diff --git a/src/simplex_main.cc b/src/simplex_main.cc index e8da07d..8f5b555 100644 --- a/src/simplex_main.cc +++ b/src/simplex_main.cc @@ -416,97 +416,65 @@ int main(int argc, char* argv[]) { std::vector shots; std::unique_ptr writer; args.extract(config, shots, writer); - std::atomic next_unclaimed_shot; - std::vector> finished(shots.size()); std::vector obs_predicted(shots.size()); std::vector cost_predicted(shots.size()); std::vector decoding_time_seconds(shots.size()); - std::vector decoder_threads; const stim::DetectorErrorModel original_dem = config.dem.flattened(); std::vector> error_use_totals(original_dem.count_errors()); bool has_obs = args.has_observables(); - std::atomic worker_threads_please_terminate = false; - 
std::atomic num_worker_threads_active; - for (size_t t = 0; t < args.num_threads; ++t) { - // After this value returns to 0, we know that no further shots will - // transition to finished. - ++num_worker_threads_active; - decoder_threads.push_back(std::thread([&config, &next_unclaimed_shot, &shots, &obs_predicted, - &cost_predicted, &decoding_time_seconds, &finished, - &error_use_totals, &has_obs, - &worker_threads_please_terminate, - &num_worker_threads_active, &original_dem]() { - SimplexDecoder decoder(config); - std::vector error_use(original_dem.count_errors()); - for (size_t shot; - !worker_threads_please_terminate and ((shot = next_unclaimed_shot++) < shots.size());) { + size_t num_errors = 0; + double total_time_seconds = 0; + size_t num_observables = config.dem.count_observables(); + size_t shot = parallel_for_shots_in_order( + shots.size(), args.num_threads, + [&]() { + struct ThreadState { + SimplexDecoder decoder; + std::vector error_use; + explicit ThreadState(const SimplexConfig& config, size_t num_errors) + : decoder(config), error_use(num_errors) {} + }; + return ThreadState(config, original_dem.count_errors()); + }, + [&](auto& thread_state, size_t shot_index) { auto start_time = std::chrono::high_resolution_clock::now(); - decoder.decode_to_errors(shots[shot].hits); + thread_state.decoder.decode_to_errors(shots[shot_index].hits); auto stop_time = std::chrono::high_resolution_clock::now(); - decoding_time_seconds[shot] = + decoding_time_seconds[shot_index] = std::chrono::duration_cast(stop_time - start_time).count() / 1e6; - obs_predicted[shot] = - vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer)); - cost_predicted[shot] = decoder.cost_from_errors(decoder.predicted_errors_buffer); - if (!has_obs or shots[shot].obs_mask_as_u64() == obs_predicted[shot]) { - // Only count the error uses for shots that did not have a logical - // error, if we know the obs flips. 
- for (size_t ei : decoder.predicted_errors_buffer) { - ++error_use[ei]; + obs_predicted[shot_index] = vector_to_u64_mask( + thread_state.decoder.get_flipped_observables(thread_state.decoder.predicted_errors_buffer)); + cost_predicted[shot_index] = + thread_state.decoder.cost_from_errors(thread_state.decoder.predicted_errors_buffer); + if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { + for (size_t ei : thread_state.decoder.predicted_errors_buffer) { + ++thread_state.error_use[ei]; } } - finished[shot] = true; - } - // Add the error counts to the total - for (size_t ei = 0; ei < config.dem.count_errors(); ++ei) { - error_use_totals[ei] += error_use[ei]; - } - --num_worker_threads_active; - })); - } - size_t num_errors = 0; - double total_time_seconds = 0; - size_t num_observables = config.dem.count_observables(); - size_t shot = 0; - for (; shot < shots.size(); ++shot) { - while (num_worker_threads_active and !finished[shot]) { - // We break once the number of active worker threads is 0, at which point - // there will be no further changes to finished[shot]. - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - // There can be no further changes to finished[shot]. If it is true, we - // process it and go to the next shot. If it is false, we break now as it - // will never be decoded and no subsequent shots will be decoded. - if (!finished[shot]) { - assert(num_worker_threads_active == 0); - // This and subsequent shots will never become decoded. 
- break; - } - - if (writer) { - writer->write_bits((uint8_t*)&obs_predicted[shot], num_observables); - writer->write_end(); - } - - if (obs_predicted[shot] != shots[shot].obs_mask_as_u64()) ++num_errors; - - total_time_seconds += decoding_time_seconds[shot]; - - if (args.print_stats) { - std::cout << "num_shots = " << (shot + 1) << " num_errors = " << num_errors - << " total_time_seconds = " << total_time_seconds << std::endl; - std::cout << "cost = " << cost_predicted[shot] << std::endl; - std::cout.flush(); - } - - if (num_errors >= args.max_errors) { - worker_threads_please_terminate = true; - } - } - for (size_t t = 0; t < args.num_threads; ++t) { - decoder_threads[t].join(); - } + }, + [&](auto& thread_state) { + for (size_t ei = 0; ei < config.dem.count_errors(); ++ei) { + error_use_totals[ei] += thread_state.error_use[ei]; + } + }, + [&](size_t shot_index) { + if (writer) { + writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables); + writer->write_end(); + } + if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) { + ++num_errors; + } + total_time_seconds += decoding_time_seconds[shot_index]; + if (args.print_stats) { + std::cout << "num_shots = " << (shot_index + 1) << " num_errors = " << num_errors + << " total_time_seconds = " << total_time_seconds << std::endl; + std::cout << "cost = " << cost_predicted[shot_index] << std::endl; + std::cout.flush(); + } + return num_errors < args.max_errors; + }); if (!args.dem_out_fname.empty()) { std::vector counts(error_use_totals.begin(), error_use_totals.end()); diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index 65fb4e2..711b13b 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -17,9 +17,9 @@ #include #include #include +#include #include #include -#include #include "common.h" #include "stim.h" @@ -475,105 +475,72 @@ int main(int argc, char* argv[]) { std::vector shots; std::unique_ptr writer; args.extract(config, shots, writer); - std::atomic 
next_unclaimed_shot; - std::vector> finished(shots.size()); std::vector obs_predicted(shots.size()); std::vector cost_predicted(shots.size()); std::vector decoding_time_seconds(shots.size()); - std::vector> low_confidence(shots.size()); - std::vector decoder_threads; + std::vector low_confidence(shots.size()); const stim::DetectorErrorModel original_dem = config.dem.flattened(); std::vector> error_use_totals(original_dem.count_errors()); bool has_obs = args.has_observables(); - std::atomic worker_threads_please_terminate = false; - std::atomic num_worker_threads_active; - for (size_t t = 0; t < args.num_threads; ++t) { - // After this value returns to 0, we know that no further shots will - // transition to finished. - ++num_worker_threads_active; - decoder_threads.push_back(std::thread([&config, &next_unclaimed_shot, &shots, &obs_predicted, - &cost_predicted, &decoding_time_seconds, &low_confidence, - &finished, &error_use_totals, &has_obs, - &worker_threads_please_terminate, - &num_worker_threads_active, &original_dem]() { - TesseractDecoder decoder(config); - std::vector error_use(original_dem.count_errors()); - for (size_t shot; - !worker_threads_please_terminate and ((shot = next_unclaimed_shot++) < shots.size());) { + size_t num_errors = 0; + size_t num_low_confidence = 0; + double total_time_seconds = 0; + size_t num_observables = config.dem.count_observables(); + size_t shot = parallel_for_shots_in_order( + shots.size(), args.num_threads, + [&]() { + struct ThreadState { + TesseractDecoder decoder; + std::vector error_use; + explicit ThreadState(const TesseractConfig& config, size_t num_errors) + : decoder(config), error_use(num_errors) {} + }; + return ThreadState(config, original_dem.count_errors()); + }, + [&](auto& thread_state, size_t shot_index) { auto start_time = std::chrono::high_resolution_clock::now(); - decoder.decode_to_errors(shots[shot].hits); + thread_state.decoder.decode_to_errors(shots[shot_index].hits); auto stop_time = 
std::chrono::high_resolution_clock::now(); - decoding_time_seconds[shot] = + decoding_time_seconds[shot_index] = std::chrono::duration_cast(stop_time - start_time).count() / 1e6; - obs_predicted[shot] = - vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer)); - low_confidence[shot] = decoder.low_confidence_flag; - cost_predicted[shot] = decoder.cost_from_errors(decoder.predicted_errors_buffer); - if (!has_obs or shots[shot].obs_mask_as_u64() == obs_predicted[shot]) { - // Only count the error uses for shots that did not have a logical - // error, if we know the obs flips. - for (size_t ei : decoder.predicted_errors_buffer) { - ++error_use[ei]; + obs_predicted[shot_index] = vector_to_u64_mask( + thread_state.decoder.get_flipped_observables(thread_state.decoder.predicted_errors_buffer)); + low_confidence[shot_index] = thread_state.decoder.low_confidence_flag; + cost_predicted[shot_index] = + thread_state.decoder.cost_from_errors(thread_state.decoder.predicted_errors_buffer); + if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { + for (size_t ei : thread_state.decoder.predicted_errors_buffer) { + ++thread_state.error_use[ei]; } } - finished[shot] = true; - } - // Add the error counts to the total - for (size_t ei = 0; ei < error_use_totals.size(); ++ei) { - error_use_totals[ei] += error_use[ei]; - } - --num_worker_threads_active; - })); - } - size_t num_errors = 0; - size_t num_low_confidence = 0; - double total_time_seconds = 0; - size_t num_observables = config.dem.count_observables(); - size_t shot = 0; - for (; shot < shots.size(); ++shot) { - while (num_worker_threads_active and !finished[shot]) { - // We break once the number of active worker threads is 0, at which point - // there will be no further changes to finished[shot]. - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - // There can be no further changes to finished[shot]. 
If it is true, we - // process it and go to the next shot. If it is false, we break now as it - // will never be decoded and no subsequent shots will be decoded. - if (!finished[shot]) { - assert(num_worker_threads_active == 0); - // This and subsequent shots will never become decoded. - break; - } - - if (writer) { - writer->write_bits((uint8_t*)&obs_predicted[shot], num_observables); - writer->write_end(); - } - - if (low_confidence[shot]) { - ++num_low_confidence; - } else if (obs_predicted[shot] != shots[shot].obs_mask_as_u64()) { - ++num_errors; - } - - total_time_seconds += decoding_time_seconds[shot]; - - if (args.print_stats) { - std::cout << "num_shots = " << (shot + 1) << " num_low_confidence = " << num_low_confidence - << " num_errors = " << num_errors << " total_time_seconds = " << total_time_seconds - << std::endl; - std::cout << "cost = " << cost_predicted[shot] << std::endl; - std::cout.flush(); - } - - if (num_errors >= args.max_errors) { - worker_threads_please_terminate = true; - } - } - for (size_t t = 0; t < args.num_threads; ++t) { - decoder_threads[t].join(); - } + }, + [&](auto& thread_state) { + for (size_t ei = 0; ei < error_use_totals.size(); ++ei) { + error_use_totals[ei] += thread_state.error_use[ei]; + } + }, + [&](size_t shot_index) { + if (writer) { + writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables); + writer->write_end(); + } + if (low_confidence[shot_index]) { + ++num_low_confidence; + } else if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) { + ++num_errors; + } + total_time_seconds += decoding_time_seconds[shot_index]; + if (args.print_stats) { + std::cout << "num_shots = " << (shot_index + 1) + << " num_low_confidence = " << num_low_confidence + << " num_errors = " << num_errors + << " total_time_seconds = " << total_time_seconds << std::endl; + std::cout << "cost = " << cost_predicted[shot_index] << std::endl; + std::cout.flush(); + } + return num_errors < args.max_errors; + }); if 
(!args.dem_out_fname.empty()) { std::vector counts(error_use_totals.begin(), error_use_totals.end()); diff --git a/src/utils.h b/src/utils.h index 73d7817..46affeb 100644 --- a/src/utils.h +++ b/src/utils.h @@ -16,10 +16,13 @@ #define __TESSERACT_UTILS_H__ #include +#include #include +#include #include #include #include +#include #include #include @@ -54,4 +57,57 @@ std::vector get_errors_from_dem(const stim::DetectorErrorModel& d std::vector get_files_recursive(const std::string& directory_path); uint64_t vector_to_u64_mask(const std::vector& v); + +// Applies a shot-wise worker function in parallel while consuming completed +// shots in increasing order. If consume_shot returns false, pending workers are +// asked to stop early from claiming new shots, but workers always finish any +// shot they already started. +template +size_t parallel_for_shots_in_order(size_t num_shots, size_t num_threads, + MakeThreadState&& make_thread_state, + ProcessShot&& process_shot, + FinalizeThread&& finalize_thread, + ConsumeShot&& consume_shot) { + std::atomic next_unclaimed_shot = 0; + std::vector> finished(num_shots); + std::atomic worker_threads_please_terminate = false; + std::atomic num_worker_threads_active = 0; + std::vector workers; + workers.reserve(num_threads); + + for (size_t t = 0; t < num_threads; ++t) { + ++num_worker_threads_active; + workers.emplace_back([&, t]() { + auto thread_state = make_thread_state(); + for (size_t shot; !worker_threads_please_terminate && + ((shot = next_unclaimed_shot++) < num_shots);) { + process_shot(thread_state, shot); + finished[shot] = true; + } + finalize_thread(thread_state); + --num_worker_threads_active; + }); + } + + size_t shot = 0; + for (; shot < num_shots; ++shot) { + while (num_worker_threads_active && !finished[shot]) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + if (!finished[shot]) { + assert(num_worker_threads_active == 0); + break; + } + if (!consume_shot(shot)) { + 
worker_threads_please_terminate = true; + } + } + + for (auto& worker : workers) { + worker.join(); + } + return shot; +} + #endif // __TESSERACT_UTILS_H__ From 3f81b5fade3d3bf644c469e790514ebb74d9cf7b Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sun, 22 Mar 2026 10:06:03 +0200 Subject: [PATCH 06/10] Fix low_confidence data race in tesseract main --- src/tesseract_main.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index 711b13b..9ecacda 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -478,7 +478,7 @@ int main(int argc, char* argv[]) { std::vector obs_predicted(shots.size()); std::vector cost_predicted(shots.size()); std::vector decoding_time_seconds(shots.size()); - std::vector low_confidence(shots.size()); + std::vector low_confidence(shots.size()); const stim::DetectorErrorModel original_dem = config.dem.flattened(); std::vector> error_use_totals(original_dem.count_errors()); bool has_obs = args.has_observables(); @@ -506,7 +506,7 @@ int main(int argc, char* argv[]) { 1e6; obs_predicted[shot_index] = vector_to_u64_mask( thread_state.decoder.get_flipped_observables(thread_state.decoder.predicted_errors_buffer)); - low_confidence[shot_index] = thread_state.decoder.low_confidence_flag; + low_confidence[shot_index] = thread_state.decoder.low_confidence_flag ? 
1 : 0; cost_predicted[shot_index] = thread_state.decoder.cost_from_errors(thread_state.decoder.predicted_errors_buffer); if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { From 89cc98b9bb4c827fb561734c982cb45b8bef8cbb Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sun, 22 Mar 2026 10:06:19 +0200 Subject: [PATCH 07/10] Restore atomic low_confidence flags in tesseract --- src/tesseract_main.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index 9ecacda..70f33d0 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -478,7 +478,7 @@ int main(int argc, char* argv[]) { std::vector obs_predicted(shots.size()); std::vector cost_predicted(shots.size()); std::vector decoding_time_seconds(shots.size()); - std::vector low_confidence(shots.size()); + std::vector> low_confidence(shots.size()); const stim::DetectorErrorModel original_dem = config.dem.flattened(); std::vector> error_use_totals(original_dem.count_errors()); bool has_obs = args.has_observables(); @@ -506,7 +506,7 @@ int main(int argc, char* argv[]) { 1e6; obs_predicted[shot_index] = vector_to_u64_mask( thread_state.decoder.get_flipped_observables(thread_state.decoder.predicted_errors_buffer)); - low_confidence[shot_index] = thread_state.decoder.low_confidence_flag ? 
1 : 0; + low_confidence[shot_index] = thread_state.decoder.low_confidence_flag; cost_predicted[shot_index] = thread_state.decoder.cost_from_errors(thread_state.decoder.predicted_errors_buffer); if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { From c2e5c8ee3d07816464c5e4f6338082c66c7569a9 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Tue, 24 Mar 2026 00:50:06 +0200 Subject: [PATCH 08/10] Simplify shot helper API and use explicit per-thread state --- src/simplex_main.cc | 46 ++++++++++++++++++++--------------------- src/tesseract_main.cc | 48 +++++++++++++++++++++---------------------- src/utils.h | 9 ++------ 3 files changed, 49 insertions(+), 54 deletions(-) diff --git a/src/simplex_main.cc b/src/simplex_main.cc index 8f5b555..e2f4d83 100644 --- a/src/simplex_main.cc +++ b/src/simplex_main.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -420,44 +421,36 @@ int main(int argc, char* argv[]) { std::vector cost_predicted(shots.size()); std::vector decoding_time_seconds(shots.size()); const stim::DetectorErrorModel original_dem = config.dem.flattened(); - std::vector> error_use_totals(original_dem.count_errors()); + std::vector> decoders(args.num_threads); + std::vector> error_use_per_thread( + args.num_threads, std::vector(original_dem.count_errors())); bool has_obs = args.has_observables(); size_t num_errors = 0; double total_time_seconds = 0; size_t num_observables = config.dem.count_observables(); size_t shot = parallel_for_shots_in_order( shots.size(), args.num_threads, - [&]() { - struct ThreadState { - SimplexDecoder decoder; - std::vector error_use; - explicit ThreadState(const SimplexConfig& config, size_t num_errors) - : decoder(config), error_use(num_errors) {} - }; - return ThreadState(config, original_dem.count_errors()); - }, - [&](auto& thread_state, size_t shot_index) { + [&](size_t thread_index, size_t shot_index) { + if (!decoders[thread_index]) { + decoders[thread_index] = 
std::make_unique(config); + } + auto& decoder = *decoders[thread_index]; + auto& error_use = error_use_per_thread[thread_index]; auto start_time = std::chrono::high_resolution_clock::now(); - thread_state.decoder.decode_to_errors(shots[shot_index].hits); + decoder.decode_to_errors(shots[shot_index].hits); auto stop_time = std::chrono::high_resolution_clock::now(); decoding_time_seconds[shot_index] = std::chrono::duration_cast(stop_time - start_time).count() / 1e6; - obs_predicted[shot_index] = vector_to_u64_mask( - thread_state.decoder.get_flipped_observables(thread_state.decoder.predicted_errors_buffer)); - cost_predicted[shot_index] = - thread_state.decoder.cost_from_errors(thread_state.decoder.predicted_errors_buffer); + obs_predicted[shot_index] = + vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer)); + cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer); if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { - for (size_t ei : thread_state.decoder.predicted_errors_buffer) { - ++thread_state.error_use[ei]; + for (size_t ei : decoder.predicted_errors_buffer) { + ++error_use[ei]; } } }, - [&](auto& thread_state) { - for (size_t ei = 0; ei < config.dem.count_errors(); ++ei) { - error_use_totals[ei] += thread_state.error_use[ei]; - } - }, [&](size_t shot_index) { if (writer) { writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables); @@ -476,6 +469,13 @@ int main(int argc, char* argv[]) { return num_errors < args.max_errors; }); + std::vector error_use_totals(original_dem.count_errors()); + for (const auto& error_use : error_use_per_thread) { + for (size_t ei = 0; ei < error_use_totals.size(); ++ei) { + error_use_totals[ei] += error_use[ei]; + } + } + if (!args.dem_out_fname.empty()) { std::vector counts(error_use_totals.begin(), error_use_totals.end()); size_t num_usage_dem_shots = shot; diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index 
70f33d0..11d610b 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -480,7 +481,9 @@ int main(int argc, char* argv[]) { std::vector decoding_time_seconds(shots.size()); std::vector> low_confidence(shots.size()); const stim::DetectorErrorModel original_dem = config.dem.flattened(); - std::vector> error_use_totals(original_dem.count_errors()); + std::vector> decoders(args.num_threads); + std::vector> error_use_per_thread( + args.num_threads, std::vector(original_dem.count_errors())); bool has_obs = args.has_observables(); size_t num_errors = 0; size_t num_low_confidence = 0; @@ -488,38 +491,28 @@ int main(int argc, char* argv[]) { size_t num_observables = config.dem.count_observables(); size_t shot = parallel_for_shots_in_order( shots.size(), args.num_threads, - [&]() { - struct ThreadState { - TesseractDecoder decoder; - std::vector error_use; - explicit ThreadState(const TesseractConfig& config, size_t num_errors) - : decoder(config), error_use(num_errors) {} - }; - return ThreadState(config, original_dem.count_errors()); - }, - [&](auto& thread_state, size_t shot_index) { + [&](size_t thread_index, size_t shot_index) { + if (!decoders[thread_index]) { + decoders[thread_index] = std::make_unique(config); + } + auto& decoder = *decoders[thread_index]; + auto& error_use = error_use_per_thread[thread_index]; auto start_time = std::chrono::high_resolution_clock::now(); - thread_state.decoder.decode_to_errors(shots[shot_index].hits); + decoder.decode_to_errors(shots[shot_index].hits); auto stop_time = std::chrono::high_resolution_clock::now(); decoding_time_seconds[shot_index] = std::chrono::duration_cast(stop_time - start_time).count() / 1e6; - obs_predicted[shot_index] = vector_to_u64_mask( - thread_state.decoder.get_flipped_observables(thread_state.decoder.predicted_errors_buffer)); - low_confidence[shot_index] = thread_state.decoder.low_confidence_flag; - 
cost_predicted[shot_index] = - thread_state.decoder.cost_from_errors(thread_state.decoder.predicted_errors_buffer); + obs_predicted[shot_index] = + vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer)); + low_confidence[shot_index] = decoder.low_confidence_flag; + cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer); if (!has_obs or shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { - for (size_t ei : thread_state.decoder.predicted_errors_buffer) { - ++thread_state.error_use[ei]; + for (size_t ei : decoder.predicted_errors_buffer) { + ++error_use[ei]; } } }, - [&](auto& thread_state) { - for (size_t ei = 0; ei < error_use_totals.size(); ++ei) { - error_use_totals[ei] += thread_state.error_use[ei]; - } - }, [&](size_t shot_index) { if (writer) { writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables); @@ -542,6 +535,13 @@ int main(int argc, char* argv[]) { return num_errors < args.max_errors; }); + std::vector error_use_totals(original_dem.count_errors()); + for (const auto& error_use : error_use_per_thread) { + for (size_t ei = 0; ei < error_use_totals.size(); ++ei) { + error_use_totals[ei] += error_use[ei]; + } + } + if (!args.dem_out_fname.empty()) { std::vector counts(error_use_totals.begin(), error_use_totals.end()); size_t num_usage_dem_shots = shot; diff --git a/src/utils.h b/src/utils.h index 46affeb..4999f08 100644 --- a/src/utils.h +++ b/src/utils.h @@ -62,12 +62,9 @@ uint64_t vector_to_u64_mask(const std::vector& v); // shots in increasing order. If consume_shot returns false, pending workers are // asked to stop early from claiming new shots, but workers always finish any // shot they already started. 
-template +template size_t parallel_for_shots_in_order(size_t num_shots, size_t num_threads, - MakeThreadState&& make_thread_state, ProcessShot&& process_shot, - FinalizeThread&& finalize_thread, ConsumeShot&& consume_shot) { std::atomic next_unclaimed_shot = 0; std::vector> finished(num_shots); @@ -79,13 +76,11 @@ size_t parallel_for_shots_in_order(size_t num_shots, size_t num_threads, for (size_t t = 0; t < num_threads; ++t) { ++num_worker_threads_active; workers.emplace_back([&, t]() { - auto thread_state = make_thread_state(); for (size_t shot; !worker_threads_please_terminate && ((shot = next_unclaimed_shot++) < num_shots);) { - process_shot(thread_state, shot); + process_shot(t, shot); finished[shot] = true; } - finalize_thread(thread_state); --num_worker_threads_active; }); } From 9489da9d77d12eb580ca12beaa727c0c6dfff9c5 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Tue, 24 Mar 2026 01:50:59 +0200 Subject: [PATCH 09/10] Harden thread-count validation and document helper callbacks --- src/simplex_main.cc | 7 ++++++- src/tesseract_main.cc | 7 ++++++- src/utils.h | 14 +++++++++++--- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/simplex_main.cc b/src/simplex_main.cc index e2f4d83..63f0075 100644 --- a/src/simplex_main.cc +++ b/src/simplex_main.cc @@ -108,6 +108,9 @@ struct Args { "Cannot load observable flips without a corresponding detection " "event data file."); } + if (num_threads == 0) { + throw std::invalid_argument("--threads must be at least 1."); + } if (num_threads > 1000) { throw std::invalid_argument( "There is a maximum limit of 1000 threads imposed to avoid " @@ -368,7 +371,9 @@ int main(int argc, char* argv[]) { program.add_argument("--threads") .help("Number of decoder threads to use") .metavar("N") - .default_value(size_t(std::thread::hardware_concurrency())) + .default_value(size_t(std::thread::hardware_concurrency() == 0 + ? 
1 + : std::thread::hardware_concurrency())) .store_into(args.num_threads); program.add_argument("--parallelize-ilp") .help( diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index 11d610b..c068d9c 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -121,6 +121,9 @@ struct Args { "Cannot load observable flips without a corresponding detection " "event data file."); } + if (num_threads == 0) { + throw std::invalid_argument("--threads must be at least 1."); + } if (num_threads > 1000) { throw std::invalid_argument( "There is a maximum limit of 1000 threads imposed to avoid " @@ -425,7 +428,9 @@ int main(int argc, char* argv[]) { program.add_argument("--threads") .help("Number of decoder threads to use") .metavar("N") - .default_value(size_t(std::thread::hardware_concurrency())) + .default_value(size_t(std::thread::hardware_concurrency() == 0 + ? 1 + : std::thread::hardware_concurrency())) .store_into(args.num_threads); program.add_argument("--beam") .help("Beam to use for truncation (default = infinity)") diff --git a/src/utils.h b/src/utils.h index 4999f08..48ecfad 100644 --- a/src/utils.h +++ b/src/utils.h @@ -59,9 +59,17 @@ std::vector get_files_recursive(const std::string& directory_path); uint64_t vector_to_u64_mask(const std::vector& v); // Applies a shot-wise worker function in parallel while consuming completed -// shots in increasing order. If consume_shot returns false, pending workers are -// asked to stop early from claiming new shots, but workers always finish any -// shot they already started. +// shots in increasing order. +// +// process_shot(thread_index, shot_index): +// - Runs on worker threads. +// - thread_index is stable for each worker and lies in [0, num_threads). +// +// consume_shot(shot_index): +// - Runs on the caller thread in increasing shot order. +// +// If consume_shot returns false, workers stop claiming new shots but always +// finish any shot they already started. 
template size_t parallel_for_shots_in_order(size_t num_shots, size_t num_threads, ProcessShot&& process_shot, From 92dcdba027d523236ff55c77d6d5493138c9dd18 Mon Sep 17 00:00:00 2001 From: noajshu Date: Tue, 24 Mar 2026 02:29:30 +0000 Subject: [PATCH 10/10] ran clang-format --- src/simplex_main.cc | 5 ++--- src/tesseract_main.cc | 7 +++---- src/utils.h | 7 +++---- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/simplex_main.cc b/src/simplex_main.cc index 63f0075..7939a91 100644 --- a/src/simplex_main.cc +++ b/src/simplex_main.cc @@ -371,9 +371,8 @@ int main(int argc, char* argv[]) { program.add_argument("--threads") .help("Number of decoder threads to use") .metavar("N") - .default_value(size_t(std::thread::hardware_concurrency() == 0 - ? 1 - : std::thread::hardware_concurrency())) + .default_value(size_t( + std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency())) .store_into(args.num_threads); program.add_argument("--parallelize-ilp") .help( diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index c068d9c..ab5ed9c 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -18,9 +18,9 @@ #include #include #include -#include #include #include +#include #include "common.h" #include "stim.h" @@ -428,9 +428,8 @@ int main(int argc, char* argv[]) { program.add_argument("--threads") .help("Number of decoder threads to use") .metavar("N") - .default_value(size_t(std::thread::hardware_concurrency() == 0 - ? 1 - : std::thread::hardware_concurrency())) + .default_value(size_t( + std::thread::hardware_concurrency() == 0 ? 
1 : std::thread::hardware_concurrency())) .store_into(args.num_threads); program.add_argument("--beam") .help("Beam to use for truncation (default = infinity)") diff --git a/src/utils.h b/src/utils.h index 48ecfad..fe89b4f 100644 --- a/src/utils.h +++ b/src/utils.h @@ -71,8 +71,7 @@ uint64_t vector_to_u64_mask(const std::vector& v); // If consume_shot returns false, workers stop claiming new shots but always // finish any shot they already started. template -size_t parallel_for_shots_in_order(size_t num_shots, size_t num_threads, - ProcessShot&& process_shot, +size_t parallel_for_shots_in_order(size_t num_shots, size_t num_threads, ProcessShot&& process_shot, ConsumeShot&& consume_shot) { std::atomic next_unclaimed_shot = 0; std::vector> finished(num_shots); @@ -84,8 +83,8 @@ size_t parallel_for_shots_in_order(size_t num_shots, size_t num_threads, for (size_t t = 0; t < num_threads; ++t) { ++num_worker_threads_active; workers.emplace_back([&, t]() { - for (size_t shot; !worker_threads_please_terminate && - ((shot = next_unclaimed_shot++) < num_shots);) { + for (size_t shot; + !worker_threads_please_terminate && ((shot = next_unclaimed_shot++) < num_shots);) { process_shot(t, shot); finished[shot] = true; }