Skip to content

Commit cff2fc5

Browse files
committed
update and add tests
Signed-off-by: Kaiqi Yan <kaiqiy@nvidia.com>
1 parent a14212f commit cff2fc5

1 file changed

Lines changed: 13 additions & 14 deletions

File tree

libs/qec/python/tests/test_dem.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,10 @@ def _stim_dem_to_arrays(stim_dem):
189189
h_cols.append(h_col)
190190
o_cols.append(o_col)
191191
rates.append(prob)
192-
H = np.stack(h_cols, axis=1) if h_cols else np.zeros((n_dets, 0),
193-
dtype=np.uint8)
194-
O = np.stack(o_cols, axis=1) if o_cols else np.zeros((n_obs, 0),
195-
dtype=np.uint8)
192+
H = np.stack(h_cols, axis=1) if h_cols else np.zeros(
193+
(n_dets, 0), dtype=np.uint8)
194+
O = np.stack(o_cols, axis=1) if o_cols else np.zeros(
195+
(n_obs, 0), dtype=np.uint8)
196196
return H, O, np.asarray(rates, dtype=np.float64)
197197

198198

@@ -229,8 +229,8 @@ def test_z_dem_from_memory_circuit_against_stim_oracle():
229229
"""Compare cudaq-qec's Steane Z-DEM against Stim's independent DEM."""
230230
stim_mod = pytest.importorskip(
231231
"stim",
232-
reason="stim not installed; skipping Stim oracle cross-check for Steane Z-DEM"
233-
)
232+
reason=
233+
"stim not installed; skipping Stim oracle cross-check for Steane Z-DEM")
234234

235235
code = qec.get_code('steane')
236236
p = 0.01
@@ -267,10 +267,10 @@ def test_z_dem_from_memory_circuit_against_stim_oracle():
267267

268268
# The small tolerance covers different grouping of Pauli outcomes.
269269
for k in cudaq_keys:
270-
assert np.isclose(cudaq_terms[k], stim_terms[k], atol=1e-4,
271-
rtol=1e-3), (
272-
f"probability mismatch at {k}: "
273-
f"cudaq={cudaq_terms[k]}, stim={stim_terms[k]}")
270+
assert np.isclose(
271+
cudaq_terms[k], stim_terms[k], atol=1e-4,
272+
rtol=1e-3), (f"probability mismatch at {k}: "
273+
f"cudaq={cudaq_terms[k]}, stim={stim_terms[k]}")
274274

275275

276276
def test_x_dem_from_memory_circuit():
@@ -486,8 +486,7 @@ def test_pymatching_decodes_stim_surface_code_dem():
486486
"""Decode a Stim surface-code DEM through cudaq-qec's PyMatching plugin."""
487487
stim_mod = pytest.importorskip(
488488
"stim",
489-
reason="stim not installed; skipping Stim-based PyMatching decode test"
490-
)
489+
reason="stim not installed; skipping Stim-based PyMatching decode test")
491490

492491
distance = 5
493492
n_rounds = 5
@@ -533,8 +532,8 @@ def test_pymatching_decodes_stim_surface_code_dem():
533532
data_predictions = np.round(obs_per_shot).astype(np.uint8).flatten()
534533

535534
n_errors_without_decoding = int(np.sum(logical_measurements))
536-
n_errors_with_decoding = int(
537-
np.sum(data_predictions ^ logical_measurements))
535+
n_errors_with_decoding = int(np.sum(data_predictions ^
536+
logical_measurements))
538537
assert n_errors_with_decoding < n_errors_without_decoding, (
539538
f"PyMatching did not reduce logical errors: "
540539
f"with_decoding={n_errors_with_decoding}, "

0 commit comments

Comments
 (0)