Skip to content

Commit bec4bc1

Browse files
committed
Vectorize expressibility sampling and batch fidelity computation
1 parent 0fe6be1 commit bec4bc1

2 files changed

Lines changed: 69 additions & 24 deletions

File tree

src/encoding_atlas/analysis/_utils.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -512,25 +512,26 @@ def simulate_encoding_statevectors_batch(
512512
encoding: BaseEncoding,
513513
X: NDArray[np.floating[Any]],
514514
backend: Literal["pennylane", "qiskit", "cirq"] = "pennylane",
515-
) -> list[StatevectorType]:
515+
) -> NDArray[np.complexfloating[Any, Any]]:
516516
"""Simulate encoding circuits for multiple input vectors.
517517
518518
This is a convenience function that applies :func:`simulate_encoding_statevector`
519-
to each row of a 2D input array.
519+
to each row of a 2D input array and returns a pre-allocated 2D array of
520+
statevectors.
520521
521522
Parameters
522523
----------
523524
encoding : BaseEncoding
524525
The encoding instance to simulate.
525526
X : NDArray[np.floating]
526527
Input data array of shape ``(n_samples, n_features)``.
527-
backend : {"pennylane", "qiskit"}, default="pennylane"
528+
backend : {"pennylane", "qiskit", "cirq"}, default="pennylane"
528529
The quantum simulation backend to use.
529530
530531
Returns
531532
-------
532-
list[StatevectorType]
533-
List of statevectors, one for each input sample.
533+
NDArray[np.complexfloating], shape ``(n_samples, 2**n_qubits)``
534+
2D array of statevectors, one row per input sample.
534535
535536
Raises
536537
------
@@ -554,25 +555,28 @@ def simulate_encoding_statevectors_batch(
554555
if X_array.ndim != 2:
555556
raise ValidationError(f"Input X must be 2D array, got shape {X_array.shape}")
556557

558+
n_samples = X_array.shape[0]
559+
dim = 2**encoding.n_qubits
560+
557561
_logger.debug(
558562
"Batch simulating %d samples for encoding %s",
559-
X_array.shape[0],
563+
n_samples,
560564
encoding.__class__.__name__,
561565
)
562566

563-
statevectors = []
567+
states = np.zeros((n_samples, dim), dtype=np.complex128)
564568
for i, x in enumerate(X_array):
565569
try:
566570
state = simulate_encoding_statevector(encoding, x, backend)
567-
statevectors.append(state)
571+
states[i] = np.asarray(state, dtype=np.complex128).ravel()
568572
except SimulationError as e:
569573
raise SimulationError(
570574
f"Simulation failed for sample {i}: {e}",
571575
backend=backend,
572576
details={"sample_index": i, "original_error": str(e)},
573577
) from e
574578

575-
return statevectors
579+
return states
576580

577581

578582
def _simulate_pennylane(
@@ -1469,6 +1473,53 @@ def compute_fidelity(
14691473
return fidelity
14701474

14711475

1476+
def _compute_fidelities_batch(
    states1: NDArray[np.complexfloating[Any, Any]],
    states2: NDArray[np.complexfloating[Any, Any]],
) -> NDArray[np.floating[Any]]:
    """Compute fidelities between pairs of statevectors in batch.

    Vectorized version of :func:`compute_fidelity` for arrays of states.
    Computes F_i = |⟨ψ₁ⁱ|ψ₂ⁱ⟩|² for each pair (i), i.e. row ``i`` of
    ``states1`` is paired with row ``i`` of ``states2``.

    The inputs are not normalized by this function; the result is clipped
    to ``[0, 1]``, which absorbs floating-point round-off for normalized
    states.

    Parameters
    ----------
    states1 : NDArray[np.complexfloating], shape ``(n, d)``
        First set of statevectors. Any array-like accepted by
        :func:`numpy.asarray` is allowed.
    states2 : NDArray[np.complexfloating], shape ``(n, d)``
        Second set of statevectors, same shape as ``states1``.

    Returns
    -------
    NDArray[np.floating], shape ``(n,)``
        Fidelity values as ``float64``, each in [0, 1].

    Raises
    ------
    ValueError
        If shapes of ``states1`` and ``states2`` do not match, or if the
        inputs are not 2D.
    ValidationError
        If any state contains NaN or infinite values.
    """
    # Coerce up front so nested lists/tuples work and ``.shape``/``.ndim``
    # are always available for the validation below.
    states1 = np.asarray(states1)
    states2 = np.asarray(states2)

    if states1.shape != states2.shape:
        raise ValueError(
            f"States must have same shape: got {states1.shape} and {states2.shape}"
        )
    if states1.ndim != 2:
        raise ValueError(f"States must be 2D arrays, got ndim={states1.ndim}")

    # ``isfinite`` rejects both NaN and +/-Inf in a single pass per array.
    if not np.all(np.isfinite(states1)):
        raise ValidationError("states1 contains NaN or infinite values")
    if not np.all(np.isfinite(states2)):
        raise ValidationError("states2 contains NaN or infinite values")

    # Row-wise inner product ⟨ψ₁ⁱ|ψ₂ⁱ⟩; einsum reduces over the state
    # dimension without materializing the (n, d) element-wise product.
    overlaps = np.einsum("ij,ij->i", np.conj(states1), states2)
    fidelities = np.abs(overlaps) ** 2

    # Clip round-off excursions outside [0, 1]; skip the copy when the
    # result is already float64.
    return np.clip(fidelities, 0.0, 1.0).astype(np.float64, copy=False)
1521+
1522+
14721523
def compute_purity(
14731524
density_matrix: DensityMatrixType,
14741525
) -> float:

src/encoding_atlas/analysis/expressibility.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,11 +1040,9 @@ def _sample_fidelities(
10401040
) -> NDArray[np.floating[Any]]:
10411041
"""Sample fidelities between random input pairs.
10421042
1043-
This is the core sampling loop for expressibility computation.
1044-
For each of n_samples iterations:
1045-
1. Generate two random input vectors x1, x2
1046-
2. Simulate encoding to get states |ψ(x1)⟩, |ψ(x2)⟩
1047-
3. Compute fidelity F = |⟨ψ(x1)|ψ(x2)⟩|²
1043+
Generates random input pairs in batch, simulates encoding circuits,
1044+
and computes fidelities. RNG calls are vectorized into single batch
1045+
operations for reduced overhead.
10481046
10491047
Parameters
10501048
----------
@@ -1073,24 +1071,20 @@ def _sample_fidelities(
10731071
SimulationError
10741072
If circuit simulation fails.
10751073
"""
1076-
# TODO: Vectorize this loop — batch rng.uniform calls and simulation
1077-
# calls to reduce per-sample overhead. This is the main performance
1078-
# bottleneck for large n_samples. See simulate_encoding_statevectors_batch
1079-
# in _utils.py for a possible starting point.
1074+
# Batch RNG generation — single call per input set
1075+
X1 = rng.uniform(input_range[0], input_range[1], size=(n_samples, n_features))
1076+
X2 = rng.uniform(input_range[0], input_range[1], size=(n_samples, n_features))
1077+
10801078
fidelities = np.zeros(n_samples, dtype=np.float64)
10811079

10821080
# Logging interval (log every 10% of progress)
10831081
log_interval = max(1, n_samples // 10)
10841082

10851083
for i in range(n_samples):
1086-
# Generate two random input vectors
1087-
x1 = rng.uniform(input_range[0], input_range[1], size=n_features)
1088-
x2 = rng.uniform(input_range[0], input_range[1], size=n_features)
1089-
10901084
try:
10911085
# Simulate encoding to get statevectors
1092-
state1 = simulate_encoding_statevector(encoding, x1, backend=backend)
1093-
state2 = simulate_encoding_statevector(encoding, x2, backend=backend)
1086+
state1 = simulate_encoding_statevector(encoding, X1[i], backend=backend)
1087+
state2 = simulate_encoding_statevector(encoding, X2[i], backend=backend)
10941088

10951089
# Compute fidelity: F = |⟨ψ₁|ψ₂⟩|²
10961090
fidelity = compute_fidelity(state1, state2)

0 commit comments

Comments
 (0)