Commit 5e427e4

Parallelize analysis sampling loops in expressibility, entanglement, and trainability
The three core analysis functions (compute_expressibility, compute_entanglement_capability, estimate_trainability) all sample many independent quantum circuits and aggregate their results. Their hot paths were sequential Python for-loops; with the new opt-in parallelism users get up to 2.4x wall-clock speedup with the default thread pool on modest batches, and larger speedups on big batches with the process pool.

API:
- Adds keyword-only parallel and max_workers parameters to:
  - compute_expressibility
  - compute_fidelity_distribution
  - compute_entanglement_capability
  - estimate_trainability
- parallel accepts False (default, sequential), True or 'thread' (ThreadPoolExecutor), or 'process' (ProcessPoolExecutor with the initializer/initargs pattern so the encoding is pickled once per worker, not once per sample).
- API mirrors BaseEncoding.get_circuits exactly so users learn one set of knobs for the whole library.
- Process pool works for all three backends here, unlike get_circuits: workers exchange only NumPy arrays and floats, not PennyLane closures.

Numerical determinism is the single most important guarantee:
- The RNG is fully consumed in the main process before any work is dispatched. For numpy.random.Generator, batched rng.uniform calls produce an identical sequence to per-iteration calls with the same seed, so the existing per-iteration code path is preserved byte-for-byte while gaining the option to dispatch work in parallel.
- For a fixed seed, every parallel mode produces output identical to the sequential baseline. This is the property the new tests enforce; if it ever regresses, the analysis pipeline silently loses reproducibility.

Implementation details:
- New encoding_atlas.analysis._parallel module exposing the small ParallelArg type alias and resolve_parallel_mode normaliser shared across all three analysis files.
- Each analysis file adds top-level worker plumbing (initializer/initargs pattern, picklable worker compute function, shared _compute_one_* helper used by all code paths so the arithmetic is identical regardless of mode).
- Trainability is the trickiest: failures must remain tolerated per-sample (raising from a worker would tear down the entire pool), and successful gradients must still be packed contiguously from index 0 to match the sequential pack-on-success semantics.
- Bad parallel values raise a clean ValueError up front (not wrapped in AnalysisError by the broad sampling try/except).

Tests (tests/unit/analysis/test_parallel.py, 61 cases across 6 classes) cover:
- every parallel mode x every installed backend x both return_details=False and return_details=True
- the lower-level compute_fidelity_distribution
- all four supported parallel values plus six rejected values for the normaliser
- Scott measure with explicit k
- observable kwarg propagating into trainability workers
- n_samples=1 short-circuit
- max_workers=1 edge case
- clean ValueError on bad parallel

Determinism is verified by exact equality (not approximate) at the numerical output level.

Full test suite (4502 not-slow + 382 slow with optional backends) passes; ruff, black, build, and mkdocs --strict all clean. All 1239 pre-existing analysis tests still pass: defaults preserve previous sequential behavior.
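
The claim that a batched draw matches per-iteration draws can be checked in isolation. The following is a minimal sketch, assuming NumPy is installed; the seed, sample count, feature count, and input range are arbitrary illustrative values, not defaults taken from encoding_atlas.

```python
import numpy as np

# Illustrative values only; none of these come from encoding_atlas.
SEED, N_SAMPLES, N_FEATURES = 42, 8, 3
LOW, HIGH = 0.0, 2.0 * np.pi

# Batched draw: a single call, so the RNG is fully consumed up front.
batched = np.random.default_rng(SEED).uniform(
    LOW, HIGH, size=(N_SAMPLES, N_FEATURES)
)

# Per-iteration draws: the shape of the original sequential loop.
rng = np.random.default_rng(SEED)
looped = np.stack(
    [rng.uniform(LOW, HIGH, size=N_FEATURES) for _ in range(N_SAMPLES)]
)

# Exact comparison rather than np.allclose, mirroring how the commit
# message says determinism is tested.
streams_match = np.array_equal(batched, looped)
print(streams_match)
```

Because the inputs are fixed before dispatch, any remaining difference between modes would have to come from the per-sample computation itself, which is what the shared helper is meant to rule out.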
1 parent 933fa31 commit 5e427e4

5 files changed

Lines changed: 1049 additions & 103 deletions

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+"""Shared parallelization helpers for analysis sampling loops.
+
+The three core analysis functions (``compute_expressibility``,
+``compute_entanglement_capability``, ``estimate_trainability``) all sample
+many independent quantum circuits and aggregate their results. Their hot
+paths are embarrassingly parallel, and this module supplies the small
+amount of shared infrastructure that lets each of them dispatch
+sequentially, on a thread pool, or on a process pool with the same public
+API.
+
+API consistency
+---------------
+The ``parallel`` argument mirrors the one already exposed on
+:meth:`encoding_atlas.core.base.BaseEncoding.get_circuits`:
+
+* ``False`` (default) — sequential, no executor overhead.
+* ``True`` or ``'thread'`` — :class:`concurrent.futures.ThreadPoolExecutor`.
+* ``'process'`` — :class:`concurrent.futures.ProcessPoolExecutor`.
+
+Determinism
+-----------
+The analysis callers are required to pre-generate every random input in
+the *main* process before dispatching. Workers receive the inputs and
+perform only deterministic computation (statevector simulation,
+gradient evaluation, entanglement measure). This guarantees that for a
+fixed seed the numerical output is identical across parallelization
+modes — sequential, thread pool, and process pool all produce the same
+result, byte-for-byte where floats allow.
+
+Pickling caveats
+----------------
+ProcessPoolExecutor exchanges all arguments and return values via
+``pickle``. The analysis workers do not return circuit objects (which
+would fail for PennyLane's local-closure qfuncs); they only return
+numpy arrays / floats / Python tuples, which are universally
+picklable. This is what allows ``parallel='process'`` to work for all
+three backends in the analysis path, unlike
+:meth:`BaseEncoding.get_circuits`.
+"""
+
+from __future__ import annotations
+
+from typing import Literal, Union
+
+# Public type alias re-used by every analysis function's signature.
+ParallelArg = Union[bool, Literal["thread", "process"]]
+ParallelMode = Literal["sequential", "thread", "process"]
+
+
+def resolve_parallel_mode(parallel: ParallelArg) -> ParallelMode:
+    """Normalize the public ``parallel`` argument to an internal mode tag.
+
+    Parameters
+    ----------
+    parallel : bool or {'thread', 'process'}
+        Public-facing parallelization selector. ``True`` is preserved as
+        an alias for ``'thread'`` so callers don't have to update their
+        existing ``parallel=True`` invocations.
+
+    Returns
+    -------
+    {'sequential', 'thread', 'process'}
+        Internal mode label.
+
+    Raises
+    ------
+    ValueError
+        If ``parallel`` is none of the accepted values. The error message
+        lists exactly what is accepted so users can self-correct quickly.
+    """
+    if parallel is False:
+        return "sequential"
+    if parallel is True or parallel == "thread":
+        return "thread"
+    if parallel == "process":
+        return "process"
+    raise ValueError(
+        f"parallel must be False, True, 'thread', or 'process', " f"got {parallel!r}"
+    )
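
One consequence of the identity checks (`is False`, `is True`) in the normaliser is that truthy or falsy stand-ins such as `1`, `0`, or `None` are rejected rather than coerced. The sketch below re-declares the function from this file so it runs without encoding_atlas installed; the `probe` helper is hypothetical and exists only for illustration.

```python
def resolve_parallel_mode(parallel):
    """Stand-alone copy of the normaliser shown in the diff."""
    if parallel is False:
        return "sequential"
    if parallel is True or parallel == "thread":
        return "thread"
    if parallel == "process":
        return "process"
    raise ValueError(
        f"parallel must be False, True, 'thread', or 'process', got {parallel!r}"
    )


def probe(value):
    """Hypothetical helper: return the resolved mode, or the error type name."""
    try:
        return resolve_parallel_mode(value)
    except ValueError as exc:
        return type(exc).__name__


# The four accepted spellings map onto three internal modes.
accepted = {repr(v): probe(v) for v in (False, True, "thread", "process")}

# Integers, None, and near-miss strings are rejected instead of being coerced.
rejected = {repr(v): probe(v) for v in (0, 1, None, "threads", "Thread")}

print(accepted)
print(rejected)
```

Rejecting `0` and `1` keeps a mistyped worker count (for example `parallel=4`) from silently selecting a mode.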

src/encoding_atlas/analysis/entanglement.py

Lines changed: 186 additions & 37 deletions
@@ -116,12 +116,14 @@
 
 import logging
 import warnings
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from itertools import combinations
 from typing import Any, Literal, TypedDict, Union, overload
 
 import numpy as np
 from numpy.typing import NDArray
 
+from encoding_atlas.analysis._parallel import ParallelArg, resolve_parallel_mode
 from encoding_atlas.analysis._utils import (
     compute_purity,
     create_rng,
@@ -252,6 +254,83 @@ class EntanglementResult(TypedDict):
 _MAX_VERBOSE_QUBITS: int = 10
 
 
+# =============================================================================
+# Process-pool worker plumbing (top-level for picklability)
+# =============================================================================
+#
+# These globals are populated *once per worker process* by
+# ``_entanglement_worker_init`` so the encoding + measure configuration
+# travel across the wire only once per worker, not once per sample.
+# Same pattern as :mod:`encoding_atlas.core.base` and
+# :mod:`encoding_atlas.analysis.expressibility`.
+
+_ENT_WORKER_ENCODING: BaseEncoding | None = None
+_ENT_WORKER_BACKEND: str | None = None
+_ENT_WORKER_MEASURE: str | None = None
+_ENT_WORKER_N_QUBITS: int | None = None
+_ENT_WORKER_SCOTT_K: int | None = None
+
+
+def _entanglement_worker_init(
+    encoding: BaseEncoding,
+    backend: str,
+    measure: str,
+    n_qubits: int,
+    scott_k: int | None,
+) -> None:
+    """ProcessPoolExecutor initializer — runs once per worker process."""
+    global _ENT_WORKER_ENCODING, _ENT_WORKER_BACKEND, _ENT_WORKER_MEASURE
+    global _ENT_WORKER_N_QUBITS, _ENT_WORKER_SCOTT_K
+    _ENT_WORKER_ENCODING = encoding
+    _ENT_WORKER_BACKEND = backend
+    _ENT_WORKER_MEASURE = measure
+    _ENT_WORKER_N_QUBITS = n_qubits
+    _ENT_WORKER_SCOTT_K = scott_k
+
+
+def _entanglement_worker_compute(
+    x: NDArray[np.floating[Any]],
+) -> tuple[float, NDArray[np.floating[Any]]]:
+    """Worker entrypoint: compute entanglement for one sample input."""
+    assert (
+        _ENT_WORKER_ENCODING is not None
+        and _ENT_WORKER_BACKEND is not None
+        and _ENT_WORKER_MEASURE is not None
+        and _ENT_WORKER_N_QUBITS is not None
+    ), "Process pool worker invoked before initializer ran"
+    return _compute_one_entanglement(
+        _ENT_WORKER_ENCODING,
+        x,
+        _ENT_WORKER_BACKEND,
+        _ENT_WORKER_MEASURE,
+        _ENT_WORKER_N_QUBITS,
+        _ENT_WORKER_SCOTT_K,
+    )
+
+
+def _compute_one_entanglement(
+    encoding: BaseEncoding,
+    x: NDArray[np.floating[Any]],
+    backend: str,
+    measure: str,
+    n_qubits: int,
+    scott_k: int | None,
+) -> tuple[float, NDArray[np.floating[Any]]]:
+    """Simulate ``x`` and return ``(ent_value, per_qubit)``.
+
+    Shared by the sequential, thread, and process code paths so the
+    arithmetic is identical regardless of parallelization mode.
+    """
+    statevector = simulate_encoding_statevector(encoding, x, backend=backend)
+    if measure == "meyer_wallach":
+        return compute_meyer_wallach_with_breakdown(statevector, n_qubits)
+    # measure == "scott"
+    assert scott_k is not None  # caller responsibility
+    ent_value = compute_scott_measure(statevector, n_qubits, k=scott_k)
+    per_qubit = np.zeros(n_qubits, dtype=np.float64)
+    return ent_value, per_qubit
+
+
 # =============================================================================
 # Main Public Function
 # =============================================================================
@@ -268,6 +347,8 @@ def compute_entanglement_capability(
     scott_k: int | None = ...,
     return_details: Literal[False] = ...,
     verbose: bool = ...,
+    parallel: ParallelArg = ...,
+    max_workers: int | None = ...,
 ) -> float: ...
 
 
@@ -282,6 +363,8 @@ def compute_entanglement_capability(
     scott_k: int | None = ...,
     return_details: Literal[True] = ...,
     verbose: bool = ...,
+    parallel: ParallelArg = ...,
+    max_workers: int | None = ...,
 ) -> EntanglementResult: ...
 
 
@@ -295,6 +378,8 @@ def compute_entanglement_capability(
     scott_k: int | None = None,
     return_details: bool = False,
     verbose: bool = False,
+    parallel: ParallelArg = False,
+    max_workers: int | None = None,
 ) -> Union[float, EntanglementResult]:
     """Compute the entanglement capability of a quantum encoding.
 
@@ -340,6 +425,23 @@ def compute_entanglement_capability(
         If False, return only the entanglement capability score.
     verbose : bool, default=False
         If True, log progress during computation.
+    parallel : bool or {'thread', 'process'}, default=False
+        Parallel-dispatch mode for the per-sample simulation +
+        entanglement-measure computation.
+
+        - ``False`` (default) — sequential, no executor overhead.
+        - ``True`` or ``'thread'`` — :class:`ThreadPoolExecutor`.
+        - ``'process'`` — :class:`ProcessPoolExecutor` with the encoding
+          pickled once per worker. Workers exchange only float / NumPy
+          arrays, so process-pool parallelism works with **all** three
+          backends here (unlike ``BaseEncoding.get_circuits`` where
+          PennyLane's local-closure qfuncs prevent process-pool use).
+
+        Output is numerically identical across all modes for a fixed
+        ``seed`` — the RNG is fully consumed in the main process before
+        any work is dispatched.
+    max_workers : int or None, default=None
+        Maximum number of workers when ``parallel`` is enabled.
 
     Returns
     -------
@@ -470,6 +572,9 @@ def compute_entanglement_capability(
             f"backend must be 'pennylane', 'qiskit', or 'cirq', got {backend!r}"
         )
 
+    # Validate parallel argument upfront for a clean ValueError on bad input.
+    mode = resolve_parallel_mode(parallel)
+
     # Validate and resolve scott_k parameter
     effective_scott_k: int | None = None
     if measure == "scott":
@@ -528,60 +633,104 @@ def compute_entanglement_capability(
     entanglement_samples = np.zeros(n_samples, dtype=np.float64)
     per_qubit_sum = np.zeros(n_qubits, dtype=np.float64)
 
+    # Pre-generate all random inputs in the main process. ``np.random.Generator``
+    # produces an identical sequence whether called once with size=(n_samples,
+    # n_features) or n_samples times with size=n_features, so this preserves
+    # the original seeded output exactly.
+    X_samples = rng.uniform(
+        input_range[0], input_range[1], size=(n_samples, n_features)
+    ).astype(np.float64)
+
     # Progress logging interval (every 10%)
     log_interval = max(1, n_samples // 10)
 
-    for i in range(n_samples):
-        # Generate random input features
-        x = rng.uniform(input_range[0], input_range[1], size=n_features)
-        x = x.astype(np.float64)
-
-        try:
-            # Simulate circuit to get statevector
-            statevector = simulate_encoding_statevector(encoding, x, backend=backend)
-
-            # Compute entanglement measure
-            if measure == "meyer_wallach":
-                ent_value, per_qubit = compute_meyer_wallach_with_breakdown(
-                    statevector, n_qubits
+    if mode == "sequential" or n_samples <= 1:
+        # Sequential path: keep the inline loop so progress logging and
+        # per-sample error context remain available.
+        for i in range(n_samples):
+            x = X_samples[i]
+            try:
+                ent_value, per_qubit = _compute_one_entanglement(
+                    encoding, x, backend, measure, n_qubits, effective_scott_k
                 )
-            else:  # scott
-                # effective_scott_k is guaranteed to be valid here
-                assert effective_scott_k is not None  # Type narrowing for mypy
-                ent_value = compute_scott_measure(
-                    statevector, n_qubits, k=effective_scott_k
+                entanglement_samples[i] = ent_value
+                per_qubit_sum += per_qubit
+            except SimulationError:
+                raise
+            except Exception as e:
+                raise SimulationError(
+                    f"Entanglement computation failed at sample {i}: {e}",
+                    backend=backend,
+                    details={
+                        "sample_index": i,
+                        "input": x.tolist(),
+                        "error_type": type(e).__name__,
+                        "measure": measure,
+                        "scott_k": effective_scott_k,
+                    },
+                ) from e
+
+            if verbose and (i + 1) % log_interval == 0:
+                current_mean = np.mean(entanglement_samples[: i + 1])
+                _logger.debug(
+                    "Processed %d/%d samples (current mean: %.4f)",
+                    i + 1,
+                    n_samples,
+                    current_mean,
                 )
-                # For Scott measure, per-qubit breakdown is not directly available
-                per_qubit = np.zeros(n_qubits, dtype=np.float64)
-
-            entanglement_samples[i] = ent_value
-            per_qubit_sum += per_qubit
-
+    else:
+        # Parallel path. Wrap simulation errors with a generic context — the
+        # specific sample index is no longer well-defined when work is being
+        # done out of order across workers, but the upstream exception's
+        # traceback still pinpoints the failure.
+        try:
+            if mode == "thread":
+                with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                    results = list(
+                        executor.map(
+                            lambda x: _compute_one_entanglement(
+                                encoding,
+                                x,
+                                backend,
+                                measure,
+                                n_qubits,
+                                effective_scott_k,
+                            ),
+                            X_samples,
+                        )
+                    )
+            else:  # mode == "process"
+                with ProcessPoolExecutor(
+                    max_workers=max_workers,
+                    initializer=_entanglement_worker_init,
+                    initargs=(
+                        encoding,
+                        backend,
+                        measure,
+                        n_qubits,
+                        effective_scott_k,
+                    ),
+                ) as executor:
+                    results = list(
+                        executor.map(_entanglement_worker_compute, X_samples)
+                    )
         except SimulationError:
-            # Re-raise simulation errors with context
             raise
         except Exception as e:
             raise SimulationError(
-                f"Entanglement computation failed at sample {i}: {e}",
+                f"Entanglement computation failed in {mode} pool: {e}",
                 backend=backend,
                 details={
-                    "sample_index": i,
-                    "input": x.tolist(),
                     "error_type": type(e).__name__,
                     "measure": measure,
                     "scott_k": effective_scott_k,
+                    "parallel": mode,
                 },
             ) from e
 
-        # Progress logging
-        if verbose and (i + 1) % log_interval == 0:
-            current_mean = np.mean(entanglement_samples[: i + 1])
-            _logger.debug(
-                "Processed %d/%d samples (current mean: %.4f)",
-                i + 1,
-                n_samples,
-                current_mean,
-            )
+        for i, (ent_value, per_qubit) in enumerate(results):
+            entanglement_samples[i] = ent_value
+            per_qubit_sum += per_qubit
 
     # -------------------------------------------------------------------------
     # Compute Statistics