Skip to content

Commit 1df4bb0

Browse files
committed
optimization
1 parent 9dd1a34 commit 1df4bb0

4 files changed

Lines changed: 828 additions & 100 deletions

File tree

crates/benchmarks/benches/modules/state_vec_sims.rs

Lines changed: 101 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ pub fn benchmarks<M: Measurement>(c: &mut Criterion<M>) {
7171
bench_state_vec_scaling(c);
7272
bench_individual_gates(c);
7373
bench_measurement_scaling(c);
74+
bench_subset_measurement(c);
75+
bench_flush_scaling(c);
7476
#[cfg(feature = "parallel")]
7577
bench_parallel_execution(c);
7678
}
@@ -241,19 +243,20 @@ fn bench_individual_gates<M: Measurement>(c: &mut Criterion<M>) {
241243
group.finish();
242244
}
243245

244-
/// Benchmark measurement performance scaling across qubit counts.
245-
/// Measures all qubits after applying H to each (maximum uncertainty).
246-
/// This isolates the GPU measurement optimization (workgroup reduction vs full readback).
246+
/// Benchmark measurement performance: sequential (per-qubit) vs batch (all at once).
247+
///
248+
/// Sequential: `for q in 0..n { sim.mz(&[QubitId(q)]); }` — 2n passes over state vector.
249+
/// Batch: `sim.mz(&all_qubits)` — uses joint sampling, 2 passes over state vector.
247250
fn bench_measurement_scaling<M: Measurement>(c: &mut Criterion<M>) {
248251
let mut group = c.benchmark_group("Measurement Scaling");
249252
group.sample_size(20);
250253

251254
let qubit_counts = [10, 14, 18, 20, 22];
252255

253256
for &nq in &qubit_counts {
254-
// CPU baseline: StateVec
257+
// Sequential: one mz() call per qubit (2n passes)
255258
group.bench_with_input(
256-
BenchmarkId::new("StateVec_CPU", nq),
259+
BenchmarkId::new("mz_sequential", nq),
257260
&nq,
258261
|b, &nq| {
259262
let mut sim = StateVecSoA::new(nq);
@@ -269,7 +272,24 @@ fn bench_measurement_scaling<M: Measurement>(c: &mut Criterion<M>) {
269272
},
270273
);
271274

272-
// GPU: GpuStateVec (wgpu)
275+
// Batch: one mz() call with all qubits (2 passes via joint sampling)
276+
group.bench_with_input(
277+
BenchmarkId::new("mz_batch", nq),
278+
&nq,
279+
|b, &nq| {
280+
let mut sim = StateVecSoA::new(nq);
281+
let all_qubits: Vec<QubitId> = (0..nq).map(QubitId).collect();
282+
b.iter(|| {
283+
sim.reset();
284+
for q in 0..nq {
285+
sim.h(&[QubitId(q)]);
286+
}
287+
black_box(sim.mz(&all_qubits));
288+
});
289+
},
290+
);
291+
292+
// GPU (sequential per-qubit for comparison)
273293
#[cfg(feature = "gpu-sims")]
274294
{
275295
#[allow(clippy::cast_possible_truncation)]
@@ -296,6 +316,81 @@ fn bench_measurement_scaling<M: Measurement>(c: &mut Criterion<M>) {
296316
group.finish();
297317
}
298318

319+
/// Benchmark subset measurement: measure half the qubits (even-indexed).
320+
/// Tests the mz_joint_subset path (QEC-realistic: measure ancillas, not data qubits).
321+
fn bench_subset_measurement<M: Measurement>(c: &mut Criterion<M>) {
322+
let mut group = c.benchmark_group("Subset Measurement");
323+
group.sample_size(20);
324+
325+
let qubit_counts = [10, 14, 18, 20, 22];
326+
327+
for &nq in &qubit_counts {
328+
let half: Vec<QubitId> = (0..nq).step_by(2).map(QubitId).collect();
329+
let half_count = half.len();
330+
331+
// Sequential: one mz() per qubit
332+
group.bench_with_input(
333+
BenchmarkId::new("mz_sequential", format!("{nq}q_{half_count}m")),
334+
&nq,
335+
|b, &nq| {
336+
let mut sim = StateVecSoA::new(nq);
337+
b.iter(|| {
338+
sim.reset();
339+
for q in 0..nq {
340+
sim.h(&[QubitId(q)]);
341+
}
342+
for &q in &half {
343+
black_box(sim.mz(&[q]));
344+
}
345+
});
346+
},
347+
);
348+
349+
// Batch: one mz() with all measured qubits
350+
group.bench_with_input(
351+
BenchmarkId::new("mz_batch_subset", format!("{nq}q_{half_count}m")),
352+
&nq,
353+
|b, &nq| {
354+
let mut sim = StateVecSoA::new(nq);
355+
b.iter(|| {
356+
sim.reset();
357+
for q in 0..nq {
358+
sim.h(&[QubitId(q)]);
359+
}
360+
black_box(sim.mz(&half));
361+
});
362+
},
363+
);
364+
}
365+
366+
group.finish();
367+
}
368+
369+
/// Benchmark flush performance: H on all qubits then flush.
370+
/// Isolates the cache-blocked flush optimization from measurement.
371+
fn bench_flush_scaling<M: Measurement>(c: &mut Criterion<M>) {
372+
let mut group = c.benchmark_group("Flush Scaling");
373+
group.sample_size(20);
374+
375+
let qubit_counts = [14, 18, 20, 22];
376+
377+
for &nq in &qubit_counts {
378+
group.bench_with_input(BenchmarkId::new("h_all_flush", nq), &nq, |b, &nq| {
379+
let mut sim = StateVecSoA::new(nq);
380+
b.iter(|| {
381+
sim.reset();
382+
for q in 0..nq {
383+
sim.h(&[QubitId(q)]);
384+
}
385+
sim.flush();
386+
black_box(());
387+
});
388+
});
389+
}
390+
391+
group.finish();
392+
}
393+
299394
/// Benchmark parallel vs sequential execution for large state vectors.
300395
/// Only runs when the `parallel` feature is enabled on pecos-simulators.
301396
#[cfg(feature = "parallel")]

crates/pecos-engines/src/quantum.rs

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,11 @@ where
247247

248248
let mut measurements: Vec<usize> = Vec::new();
249249

250-
for cmd in &batch {
250+
// Use indexed iteration so we can batch consecutive MZ commands into
251+
// one simulator call, enabling joint-sampling optimizations.
252+
let mut cmd_idx = 0;
253+
while cmd_idx < batch.len() {
254+
let cmd = &batch[cmd_idx];
251255
match cmd.gate_type {
252256
GateType::X => {
253257
debug!("Processing X gate on qubits {:?}", cmd.qubits);
@@ -555,12 +559,24 @@ where
555559
}
556560
}
557561

558-
// TODO: Fix it so we have multiple result_ids or get rid of result ids...
562+
// Batch consecutive MZ commands into one simulator call.
563+
// This enables joint-sampling optimizations (fewer state vector passes).
559564
GateType::MZ | GateType::MeasureLeaked => {
560-
debug!("Processing measurement on qubits {:?}", cmd.qubits);
561-
let meas_results = self.simulator.mz(&cmd.qubits);
565+
// Collect qubits from consecutive MZ/MeasureLeaked commands
566+
let mut mz_qubits: Vec<QubitId> = cmd.qubits.to_vec();
567+
while cmd_idx + 1 < batch.len()
568+
&& matches!(
569+
batch[cmd_idx + 1].gate_type,
570+
GateType::MZ | GateType::MeasureLeaked
571+
)
572+
{
573+
cmd_idx += 1;
574+
mz_qubits.extend_from_slice(&batch[cmd_idx].qubits);
575+
}
576+
577+
debug!("Processing batched measurement on {} qubits", mz_qubits.len());
578+
let meas_results = self.simulator.mz(&mz_qubits);
562579
for meas_result in meas_results {
563-
// mz() outcome: true if projected to |1⟩, false if projected to |0⟩
564580
measurements.push(usize::from(meas_result.outcome));
565581
}
566582
}
@@ -677,6 +693,7 @@ where
677693
self.simulator.u2q(before, interaction, after, &pairs);
678694
}
679695
}
696+
cmd_idx += 1;
680697
}
681698

682699
// Create a message with the measurement results
@@ -860,7 +877,9 @@ impl Engine for SparseStabEngine {
860877
let batch = message.quantum_ops()?;
861878
let mut measurements: Vec<usize> = Vec::new();
862879

863-
for cmd in &batch {
880+
let mut cmd_idx = 0;
881+
while cmd_idx < batch.len() {
882+
let cmd = &batch[cmd_idx];
864883
match cmd.gate_type {
865884
// Single-qubit Clifford gates
866885
GateType::X
@@ -885,12 +904,22 @@ impl Engine for SparseStabEngine {
885904
| GateType::SYYdg => {
886905
self.process_two_qubit_gate(cmd.gate_type, &cmd.qubits);
887906
}
888-
// Special operations
907+
// Batch consecutive MZ/MeasureLeaked commands into one simulator call
889908
GateType::MZ | GateType::MeasureLeaked => {
890-
debug!("Processing measurement on qubits {:?}", cmd.qubits);
891-
let meas_results = self.simulator.mz(&cmd.qubits);
909+
let mut mz_qubits: Vec<QubitId> = cmd.qubits.to_vec();
910+
while cmd_idx + 1 < batch.len()
911+
&& matches!(
912+
batch[cmd_idx + 1].gate_type,
913+
GateType::MZ | GateType::MeasureLeaked
914+
)
915+
{
916+
cmd_idx += 1;
917+
mz_qubits.extend_from_slice(&batch[cmd_idx].qubits);
918+
}
919+
920+
debug!("Processing batched measurement on {} qubits", mz_qubits.len());
921+
let meas_results = self.simulator.mz(&mz_qubits);
892922
for meas_result in meas_results {
893-
// mz() outcome: true if projected to |1⟩, false if projected to |0⟩
894923
measurements.push(usize::from(meas_result.outcome));
895924
}
896925
}
@@ -987,6 +1016,7 @@ impl Engine for SparseStabEngine {
9871016
)));
9881017
}
9891018
}
1019+
cmd_idx += 1;
9901020
}
9911021

9921022
// Create a message with the measurement results

0 commit comments

Comments
 (0)