@@ -18,40 +18,55 @@ namespace cudaq::qec {
1818
1919namespace {
2020
21+ // std::mutex is enough: the factory copies the creator out before
22+ // invoking, so creators cannot re-enter the registry.
2123struct stim_dem_registry {
22- std::recursive_mutex &mutex;
24+ std::mutex &mutex;
2325 std::unordered_map<std::string, stim_dem_decoder_creator> ↦
2426};
2527stim_dem_registry get_stim_dem_registry () {
26- static std::recursive_mutex *mutex = new std::recursive_mutex ();
28+ // Heap-allocated to outlive static destructors (plugin dlclose unregister
29+ // path); matches the cudaqx extension_point pattern. See extension_point.h.
30+ static std::mutex *mutex = new std::mutex ();
2731 static auto *map =
2832 new std::unordered_map<std::string, stim_dem_decoder_creator>();
2933 return {*mutex, *map};
3034}
3135
3236} // namespace
3337
38+ dem_default_values dem_defaults_for_missing_keys (
39+ const std::function<bool (const std::string &)> &contains_user_key,
40+ const detector_error_model &dem) {
41+ dem_default_values out;
42+ if (!contains_user_key (" O" ))
43+ out.O = &dem.observables_flips_matrix ;
44+ if (!contains_user_key (" error_rate_vec" ))
45+ out.error_rate_vec = &dem.error_rates ;
46+ return out;
47+ }
48+
3449void register_stim_dem_decoder_creator (const std::string &name,
3550 stim_dem_decoder_creator creator) {
3651 auto reg = get_stim_dem_registry ();
37- std::lock_guard<std::recursive_mutex > lock (reg.mutex );
52+ std::lock_guard<std::mutex > lock (reg.mutex );
3853 reg.map [name] = std::move (creator);
3954}
4055
4156void unregister_stim_dem_decoder_creator (const std::string &name) {
4257 auto reg = get_stim_dem_registry ();
43- std::lock_guard<std::recursive_mutex > lock (reg.mutex );
58+ std::lock_guard<std::mutex > lock (reg.mutex );
4459 reg.map .erase (name);
4560}
4661
4762std::unique_ptr<decoder>
4863get_decoder_from_stim_dem (const std::string &name,
4964 const std::string &stim_dem_text,
50- const cudaqx::heterogeneous_map options) {
65+ const cudaqx::heterogeneous_map & options) {
5166 stim_dem_decoder_creator creator;
5267 {
5368 auto reg = get_stim_dem_registry ();
54- std::lock_guard<std::recursive_mutex > lock (reg.mutex );
69+ std::lock_guard<std::mutex > lock (reg.mutex );
5570 auto iter = reg.map .find (name);
5671 if (iter != reg.map .end ())
5772 creator = iter->second ;
@@ -68,10 +83,13 @@ get_decoder_from_stim_dem(const std::string &name,
6883 auto dem = dem_from_stim_text (stim_dem_text);
6984
7085 cudaqx::heterogeneous_map merged = options;
71- if (!merged.contains (" O" ))
72- merged.insert (" O" , dem.observables_flips_matrix );
73- if (!merged.contains (" error_rate_vec" ))
74- merged.insert (" error_rate_vec" , dem.error_rates );
86+ // Keep in sync with the Python binding in py_decoder.cpp.
87+ auto defaults = dem_defaults_for_missing_keys (
88+ [&](const std::string &key) { return merged.contains (key); }, dem);
89+ if (defaults.O )
90+ merged.insert (" O" , *defaults.O );
91+ if (defaults.error_rate_vec )
92+ merged.insert (" error_rate_vec" , *defaults.error_rate_vec );
7593
7694 return decoder::get (name, dem.detector_error_matrix , merged);
7795}
0 commit comments