Skip to content

Commit a536c12

Browse files
committed
optimize stab
1 parent 1df4bb0 commit a536c12

3 files changed

Lines changed: 238 additions & 12 deletions

File tree

python/pecos-rslib/src/simulator_utils.rs

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,193 @@ pub fn register_simulator_utils(m: &Bound<'_, PyModule>) -> PyResult<()> {
221221
m.add_class::<TableauWrapper>()?;
222222
Ok(())
223223
}
224+
225+
// --- Shared batch dispatch for simulator bindings ---
226+
227+
use pecos::core::QubitId;
228+
use pecos::simulators::{CliffordGateable, MeasurementResult};
229+
use pyo3::types::{PySet, PyTuple};
230+
231+
/// Extract a single qubit index from a Python location.
232+
/// Handles both bare ints and 1-tuples like `(0,)` (the `GateBindingsDict` wraps ints in tuples).
233+
pub fn extract_single_qubit(location: &Bound<'_, PyAny>) -> PyResult<usize> {
234+
if let Ok(q) = location.extract::<usize>() {
235+
return Ok(q);
236+
}
237+
if let Ok(tuple) = location.downcast::<PyTuple>() {
238+
if tuple.len() == 1 {
239+
return tuple.get_item(0)?.extract::<usize>();
240+
}
241+
}
242+
Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
243+
"Expected int or 1-tuple for single-qubit location, got {:?}",
244+
location.get_type().name()?
245+
)))
246+
}
247+
248+
/// Collect single-qubit locations from a Python set into a Vec of QubitIds.
249+
fn collect_single_qubits(locations: &Bound<'_, PySet>) -> PyResult<Vec<QubitId>> {
250+
locations
251+
.iter()
252+
.map(|l| Ok(QubitId(extract_single_qubit(&l)?)))
253+
.collect()
254+
}
255+
256+
/// Collect single-qubit locations as raw usize values.
257+
fn collect_single_qubit_indices(locations: &Bound<'_, PySet>) -> PyResult<Vec<usize>> {
258+
locations
259+
.iter()
260+
.map(|l| extract_single_qubit(&l))
261+
.collect()
262+
}
263+
264+
/// Collect two-qubit pair locations from a Python set.
265+
fn collect_pairs(locations: &Bound<'_, PySet>) -> PyResult<Vec<(QubitId, QubitId)>> {
266+
locations
267+
.iter()
268+
.map(|l| {
269+
let t: (usize, usize) = l.extract()?;
270+
Ok((QubitId(t.0), QubitId(t.1)))
271+
})
272+
.collect()
273+
}
274+
275+
/// Build a measurement output dict from qubit indices and results.
276+
fn build_meas_output(
277+
py: Python<'_>,
278+
qubits: &[usize],
279+
results: Vec<MeasurementResult>,
280+
) -> PyResult<Py<PyDict>> {
281+
let output = PyDict::new(py);
282+
for (&q, r) in qubits.iter().zip(results) {
283+
if r.outcome {
284+
output.set_item(q, 1u8)?;
285+
}
286+
}
287+
Ok(output.into())
288+
}
289+
290+
/// Try to dispatch a gate in batch mode for any `CliffordGateable` simulator.
291+
///
292+
/// Returns `Some(output_dict)` if the gate was handled, `None` to fall back to
293+
/// per-location dispatch (for parameterized gates, unknown symbols, etc.).
294+
pub fn try_clifford_batch_dispatch<S: CliffordGateable>(
295+
sim: &mut S,
296+
symbol: &str,
297+
locations: &Bound<'_, PySet>,
298+
py: Python<'_>,
299+
) -> PyResult<Option<Py<PyDict>>> {
300+
match symbol {
301+
// Identity
302+
"I" => return Ok(Some(PyDict::new(py).into())),
303+
304+
// Single-qubit Clifford gates (no return value)
305+
"X" | "Y" | "Z" | "H" | "H1" | "H+z+x" | "H2" | "H-z-x" | "H3" | "H+y-z"
306+
| "H4" | "H-y-z" | "H5" | "H-x+y" | "H6" | "H-x-y" | "F" | "F1" | "Fdg"
307+
| "F1d" | "F1dg" | "F2" | "F2dg" | "F2d" | "F3" | "F3dg" | "F3d" | "F4"
308+
| "F4dg" | "F4d" | "Q" | "SX" | "SqrtX" | "Qd" | "SXdg" | "SqrtXd" | "SqrtXdg"
309+
| "R" | "SY" | "SqrtY" | "Rd" | "SYdg" | "SqrtYd" | "SqrtYdg" | "S" | "SZ"
310+
| "SqrtZ" | "Sd" | "SZdg" | "SqrtZd" | "SqrtZdg" => {
311+
let qubits = collect_single_qubits(locations)?;
312+
match symbol {
313+
"X" => { sim.x(&qubits); }
314+
"Y" => { sim.y(&qubits); }
315+
"Z" => { sim.z(&qubits); }
316+
"H" | "H1" | "H+z+x" => { sim.h(&qubits); }
317+
"H2" | "H-z-x" => { sim.h2(&qubits); }
318+
"H3" | "H+y-z" => { sim.h3(&qubits); }
319+
"H4" | "H-y-z" => { sim.h4(&qubits); }
320+
"H5" | "H-x+y" => { sim.h5(&qubits); }
321+
"H6" | "H-x-y" => { sim.h6(&qubits); }
322+
"F" | "F1" => { sim.f(&qubits); }
323+
"Fdg" | "F1d" | "F1dg" => { sim.fdg(&qubits); }
324+
"F2" => { sim.f2(&qubits); }
325+
"F2dg" | "F2d" => { sim.f2dg(&qubits); }
326+
"F3" => { sim.f3(&qubits); }
327+
"F3dg" | "F3d" => { sim.f3dg(&qubits); }
328+
"F4" => { sim.f4(&qubits); }
329+
"F4dg" | "F4d" => { sim.f4dg(&qubits); }
330+
"Q" | "SX" | "SqrtX" => { sim.sx(&qubits); }
331+
"Qd" | "SXdg" | "SqrtXd" | "SqrtXdg" => { sim.sxdg(&qubits); }
332+
"R" | "SY" | "SqrtY" => { sim.sy(&qubits); }
333+
"Rd" | "SYdg" | "SqrtYd" | "SqrtYdg" => { sim.sydg(&qubits); }
334+
"S" | "SZ" | "SqrtZ" => { sim.sz(&qubits); }
335+
"Sd" | "SZdg" | "SqrtZd" | "SqrtZdg" => { sim.szdg(&qubits); }
336+
_ => unreachable!(),
337+
}
338+
return Ok(Some(PyDict::new(py).into()));
339+
}
340+
341+
// Preparations (no return value)
342+
"PZ" | "Init" | "Init +Z" | "init |0>" | "leak" | "leak |0>" | "unleak |0>" => {
343+
sim.pz(&collect_single_qubits(locations)?);
344+
return Ok(Some(PyDict::new(py).into()));
345+
}
346+
"PnZ" | "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" => {
347+
sim.pnz(&collect_single_qubits(locations)?);
348+
return Ok(Some(PyDict::new(py).into()));
349+
}
350+
"PX" | "Init +X" | "init |+>" => {
351+
sim.px(&collect_single_qubits(locations)?);
352+
return Ok(Some(PyDict::new(py).into()));
353+
}
354+
"PnX" | "Init -X" | "init |->" => {
355+
sim.pnx(&collect_single_qubits(locations)?);
356+
return Ok(Some(PyDict::new(py).into()));
357+
}
358+
"PY" | "Init +Y" | "init |+i>" => {
359+
sim.py(&collect_single_qubits(locations)?);
360+
return Ok(Some(PyDict::new(py).into()));
361+
}
362+
"PnY" | "Init -Y" | "init |-i>" => {
363+
sim.pny(&collect_single_qubits(locations)?);
364+
return Ok(Some(PyDict::new(py).into()));
365+
}
366+
367+
// Measurements (return outcomes)
368+
"MZ" | "Measure" | "measure Z" | "Measure +Z" => {
369+
let qubits = collect_single_qubit_indices(locations)?;
370+
let qubit_ids: Vec<QubitId> = qubits.iter().map(|&q| QubitId(q)).collect();
371+
let results = sim.mz(&qubit_ids);
372+
return Ok(Some(build_meas_output(py, &qubits, results)?));
373+
}
374+
"MX" | "Measure +X" => {
375+
let qubits = collect_single_qubit_indices(locations)?;
376+
let qubit_ids: Vec<QubitId> = qubits.iter().map(|&q| QubitId(q)).collect();
377+
let results = sim.mx(&qubit_ids);
378+
return Ok(Some(build_meas_output(py, &qubits, results)?));
379+
}
380+
"MY" | "Measure +Y" => {
381+
let qubits = collect_single_qubit_indices(locations)?;
382+
let qubit_ids: Vec<QubitId> = qubits.iter().map(|&q| QubitId(q)).collect();
383+
let results = sim.my(&qubit_ids);
384+
return Ok(Some(build_meas_output(py, &qubits, results)?));
385+
}
386+
387+
// Two-qubit Clifford gates (no return value)
388+
"CX" | "CNOT" | "CY" | "CZ" | "SZZ" | "SZZdg" | "SXX" | "SXXdg" | "SYY"
389+
| "SYYdg" | "SqrtZZ" | "SqrtZZd" | "SqrtXX" | "SqrtXXd" | "SqrtYY" | "SqrtYYd"
390+
| "SWAP" | "G" | "G2" => {
391+
let pairs = collect_pairs(locations)?;
392+
match symbol {
393+
"CX" | "CNOT" => { sim.cx(&pairs); }
394+
"CY" => { sim.cy(&pairs); }
395+
"CZ" => { sim.cz(&pairs); }
396+
"SZZ" | "SqrtZZ" => { sim.szz(&pairs); }
397+
"SZZdg" | "SqrtZZd" => { sim.szzdg(&pairs); }
398+
"SXX" | "SqrtXX" => { sim.sxx(&pairs); }
399+
"SXXdg" | "SqrtXXd" => { sim.sxxdg(&pairs); }
400+
"SYY" | "SqrtYY" => { sim.syy(&pairs); }
401+
"SYYdg" | "SqrtYYd" => { sim.syydg(&pairs); }
402+
"SWAP" => { sim.swap(&pairs); }
403+
"G" | "G2" => { sim.g(&pairs); }
404+
_ => unreachable!(),
405+
}
406+
return Ok(Some(PyDict::new(py).into()));
407+
}
408+
409+
_ => {}
410+
}
411+
412+
Ok(None)
413+
}

python/pecos-rslib/src/sparse_stab_bindings.rs

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,12 @@ impl PySparseSim {
427427
}
428428
}
429429

430-
/// High-level `run_gate` method that accepts a set of locations
430+
/// High-level `run_gate` method that accepts a set of locations.
431+
///
432+
/// For common gates without special parameters, collects all locations and
433+
/// dispatches one batched call to the simulator instead of per-location calls.
434+
/// This avoids per-location Python↔Rust overhead and enables simulator-level
435+
/// batch optimizations (gate fusion, joint measurement sampling, etc.).
431436
#[pyo3(signature = (symbol, locations, **params))]
432437
fn run_gate_highlevel(
433438
&mut self,
@@ -446,22 +451,37 @@ impl PySparseSim {
446451
return Ok(output.into());
447452
}
448453

449-
// Convert locations to a vector
450454
let locations_set: Bound<PySet> = locations.clone().cast_into()?;
455+
if locations_set.is_empty() {
456+
return Ok(output.into());
457+
}
458+
459+
// Check if params have special keys that require per-location dispatch
460+
// Gates with special params need per-location dispatch (forced outcomes,
461+
// rotation angles, conditional execution, etc.)
462+
let has_special_params = params.is_some_and(|p| !p.is_empty());
463+
464+
// Fast path: batch dispatch for common gates without special params
465+
if !has_special_params {
466+
if let Some(result) =
467+
crate::simulator_utils::try_clifford_batch_dispatch(
468+
&mut self.inner, symbol, &locations_set, py,
469+
)?
470+
{
471+
return Ok(result);
472+
}
473+
}
451474

475+
// Fallback: per-location dispatch for parameterized/special gates
452476
for location in locations_set.iter() {
453-
// Convert location to tuple
454477
let loc_tuple: Bound<'_, PyTuple> = if location.is_instance_of::<PyTuple>() {
455478
location.clone().cast_into()?
456479
} else {
457-
// Single qubit - wrap in tuple
458480
PyTuple::new(py, std::slice::from_ref(&location))?
459481
};
460482

461-
// Call the underlying run_gate_internal
462483
let result = self.run_gate_internal(symbol, &loc_tuple, params)?;
463484

464-
// Only add to output if result is Some (non-zero measurement)
465485
if let Some(value) = result {
466486
output.set_item(location, value)?;
467487
}
@@ -470,6 +490,8 @@ impl PySparseSim {
470490
Ok(output.into())
471491
}
472492

493+
// try_batch_dispatch is now shared via crate::simulator_utils::try_clifford_batch_dispatch
494+
473495
/// Execute a quantum circuit
474496
#[pyo3(signature = (circuit, removed_locations=None))]
475497
fn run_circuit(

python/pecos-rslib/src/stab_bindings.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,10 @@ impl PyStab {
379379
self.run_gate_highlevel(symbol, locations, params, py)
380380
}
381381

382-
/// High-level `run_gate` method that accepts a set of locations
382+
/// High-level `run_gate` method that accepts a set of locations.
383+
///
384+
/// Uses shared batch dispatch for common Clifford gates to avoid per-location
385+
/// overhead. Falls back to per-location dispatch for parameterized gates.
383386
#[pyo3(signature = (symbol, locations, **params))]
384387
fn run_gate_highlevel(
385388
&mut self,
@@ -398,22 +401,33 @@ impl PyStab {
398401
return Ok(output.into());
399402
}
400403

401-
// Convert locations to a vector
402404
let locations_set: Bound<PySet> = locations.clone().cast_into()?;
405+
if locations_set.is_empty() {
406+
return Ok(output.into());
407+
}
408+
409+
// Fast path: batch dispatch for common gates without special params
410+
let has_special_params = params.is_some_and(|p| !p.is_empty());
411+
if !has_special_params {
412+
if let Some(result) =
413+
crate::simulator_utils::try_clifford_batch_dispatch(
414+
&mut self.inner, symbol, &locations_set, py,
415+
)?
416+
{
417+
return Ok(result);
418+
}
419+
}
403420

421+
// Fallback: per-location dispatch
404422
for location in locations_set.iter() {
405-
// Convert location to tuple
406423
let loc_tuple: Bound<'_, PyTuple> = if location.is_instance_of::<PyTuple>() {
407424
location.clone().cast_into()?
408425
} else {
409-
// Single qubit - wrap in tuple
410426
PyTuple::new(py, std::slice::from_ref(&location))?
411427
};
412428

413-
// Call the underlying run_gate_internal
414429
let result = self.run_gate_internal(symbol, &loc_tuple, params)?;
415430

416-
// Only add to output if result is Some (non-zero measurement)
417431
if let Some(value) = result {
418432
output.set_item(location, value)?;
419433
}

0 commit comments

Comments
 (0)