1010# or implied. See the License for the specific language governing permissions and limitations under
1111# the License.
1212
13- """Integration tests for pickle-based multiprocessing of simulators."""
13+ """Integration tests for pickle-based multiprocessing of simulators.
14+
15+ Worker functions are defined in ``pecos._mp_workers`` (an installed module)
16+ rather than in this test file, because ``multiprocessing`` with the ``spawn``
17+ start method (default on macOS/Windows) requires workers to be importable
18+ by the child process. Test files live outside the installed package and
19+ cannot be imported by spawned children.
20+ """
1421
1522import multiprocessing
1623import pickle
1724import sys
1825
1926import pytest
27+ from pecos ._mp_workers import (
28+ deserialize_and_call ,
29+ run_callable_worker ,
30+ sim_run_from_bytes ,
31+ )
32+ from pecos .engines .hybrid_engine_multiprocessing import worker_wrapper
2033from pecos_rslib import CoinToss , PauliProp , SparseSim , StateVec
2134
22-
23- def _statevec_worker (sim_bytes : bytes ) -> int :
24- sim = pickle .loads (sim_bytes )
25- sim .run_1q_gate ("H" , 0 )
26- return sim .num_qubits
27-
28-
29- def _sparsesim_worker (sim_bytes : bytes ) -> int :
30- sim = pickle .loads (sim_bytes )
31- sim .run_1q_gate ("H" , 0 )
32- return sim .num_qubits
33-
34-
35- def _cointoss_worker (sim_bytes : bytes ) -> int :
36- sim = pickle .loads (sim_bytes )
37- sim .run_measure (0 )
38- return sim .num_qubits
39-
40-
41- def _pauliprop_worker (sim_bytes : bytes ) -> int :
42- sim = pickle .loads (sim_bytes )
43- sim .h (0 )
44- return sim .weight ()
45-
46-
4735# Use fork context on Linux (fast, avoids spawn serialization issues with test files).
4836# On macOS/Windows where fork is unavailable or unsafe, use spawn.
4937_MP_CONTEXT = "fork" if sys .platform == "linux" else "spawn"
5038_POOL_TIMEOUT = 60 # seconds -- fail fast instead of hanging CI
5139
5240
41+ def _get_pool_context () -> multiprocessing .context .BaseContext :
42+ return multiprocessing .get_context (_MP_CONTEXT )
43+
44+
45+ # ---------------------------------------------------------------------------
46+ # Basic pickle round-trip tests via deserialize_and_call
47+ # ---------------------------------------------------------------------------
48+
49+
5350@pytest .mark .timeout (120 )
5451class TestMultiprocessingStateVec :
5552 """Tests for multiprocessing StateVec simulators via pickle."""
@@ -59,9 +56,10 @@ def test_pool_map(self) -> None:
5956 sim = StateVec (3 , seed = 42 )
6057 sim .run_1q_gate ("H" , 0 )
6158 sim_bytes = pickle .dumps (sim )
62- ctx = multiprocessing .get_context (_MP_CONTEXT )
59+ args = [(sim_bytes , "run_1q_gate" , ("H" , 0 ), "num_qubits" , ())] * 2
60+ ctx = _get_pool_context ()
6361 with ctx .Pool (processes = 2 ) as pool :
64- results = pool .map_async (_statevec_worker , [ sim_bytes , sim_bytes ] ).get (
62+ results = pool .map_async (deserialize_and_call , args ).get (
6563 timeout = _POOL_TIMEOUT ,
6664 )
6765 assert results == [3 , 3 ]
@@ -77,9 +75,10 @@ def test_pool_map(self) -> None:
7775 sim .run_1q_gate ("H" , 0 )
7876 sim .run_2q_gate ("CX" , (0 , 1 ), None )
7977 sim_bytes = pickle .dumps (sim )
80- ctx = multiprocessing .get_context (_MP_CONTEXT )
78+ args = [(sim_bytes , "run_1q_gate" , ("H" , 0 ), "num_qubits" , ())] * 2
79+ ctx = _get_pool_context ()
8180 with ctx .Pool (processes = 2 ) as pool :
82- results = pool .map_async (_sparsesim_worker , [ sim_bytes , sim_bytes ] ).get (
81+ results = pool .map_async (deserialize_and_call , args ).get (
8382 timeout = _POOL_TIMEOUT ,
8483 )
8584 assert results == [4 , 4 ]
@@ -93,9 +92,10 @@ def test_pool_map(self) -> None:
9392 """Test CoinToss serialization works with multiprocessing Pool.map."""
9493 sim = CoinToss (5 , prob = 0.3 )
9594 sim_bytes = pickle .dumps (sim )
96- ctx = multiprocessing .get_context (_MP_CONTEXT )
95+ args = [(sim_bytes , "run_measure" , (0 ,), "num_qubits" , ())] * 2
96+ ctx = _get_pool_context ()
9797 with ctx .Pool (processes = 2 ) as pool :
98- results = pool .map_async (_cointoss_worker , [ sim_bytes , sim_bytes ] ).get (
98+ results = pool .map_async (deserialize_and_call , args ).get (
9999 timeout = _POOL_TIMEOUT ,
100100 )
101101 assert results == [5 , 5 ]
@@ -110,10 +110,163 @@ def test_pool_map(self) -> None:
110110 sim = PauliProp (num_qubits = 3 , track_sign = True )
111111 sim .add_x (0 )
112112 sim_bytes = pickle .dumps (sim )
113- ctx = multiprocessing .get_context (_MP_CONTEXT )
113+ args = [(sim_bytes , "h" , (0 ,), "weight" , ())] * 2
114+ ctx = _get_pool_context ()
114115 with ctx .Pool (processes = 2 ) as pool :
115- results = pool .map_async (_pauliprop_worker , [ sim_bytes , sim_bytes ] ).get (
116+ results = pool .map_async (deserialize_and_call , args ).get (
116117 timeout = _POOL_TIMEOUT ,
117118 )
118119 # After H on qubit 0: X->Z, so weight should still be 1
119120 assert all (r == 1 for r in results )
121+
122+
123+ # ---------------------------------------------------------------------------
124+ # Production-pattern tests: Manager queue + worker_wrapper
125+ #
126+ # These mirror the pattern used in hybrid_engine_multiprocessing.run_multisim:
127+ # 1. Create a Manager().Queue() for inter-process messaging
128+ # 2. Pass (queue, callable, kwargs, index) to worker_wrapper via pool.map
129+ # 3. worker_wrapper redirects stdout/stderr to WriteStream on the queue
130+ # 4. worker_wrapper calls the callable and returns (result_dict, run_info)
131+ # 5. Parent drains the queue and aggregates results
132+ # ---------------------------------------------------------------------------
133+
134+
135+ @pytest .mark .timeout (120 )
136+ class TestWorkerWrapperPattern :
137+ """Tests that mirror the production worker_wrapper + Manager queue pattern."""
138+
139+ def test_worker_wrapper_with_statevec (self ) -> None :
140+ """Test the production worker_wrapper pattern with StateVec."""
141+ sim = StateVec (3 , seed = 42 )
142+ sim .run_1q_gate ("H" , 0 )
143+ sim_bytes = pickle .dumps (sim )
144+
145+ ctx = _get_pool_context ()
146+ manager = ctx .Manager ()
147+ queue = manager .Queue ()
148+
149+ kwargs = {
150+ "sim_bytes" : sim_bytes ,
151+ "method" : "run_1q_gate" ,
152+ "method_args" : ("H" , 0 ),
153+ "result_attr" : "num_qubits" ,
154+ "seed" : 1 ,
155+ "shots" : 1 ,
156+ "foreign_object" : None ,
157+ }
158+ worker_args = [
159+ (queue , sim_run_from_bytes , {** kwargs , "seed" : 1 }, 0 ),
160+ (queue , sim_run_from_bytes , {** kwargs , "seed" : 2 }, 1 ),
161+ ]
162+
163+ with ctx .Pool (processes = 2 ) as pool :
164+ presults = pool .map_async (worker_wrapper , worker_args ).get (
165+ timeout = _POOL_TIMEOUT ,
166+ )
167+
168+ for result_dict , run_info in presults :
169+ assert result_dict == {"measurements" : [3 ]}
170+ assert "pid" in run_info
171+ assert "i" in run_info
172+
173+ def test_worker_wrapper_with_sparsesim (self ) -> None :
174+ """Test the production worker_wrapper pattern with SparseSim."""
175+ sim = SparseSim (4 )
176+ sim .run_1q_gate ("H" , 0 )
177+ sim .run_2q_gate ("CX" , (0 , 1 ), None )
178+ sim_bytes = pickle .dumps (sim )
179+
180+ ctx = _get_pool_context ()
181+ manager = ctx .Manager ()
182+ queue = manager .Queue ()
183+
184+ kwargs = {
185+ "sim_bytes" : sim_bytes ,
186+ "method" : "run_1q_gate" ,
187+ "method_args" : ("H" , 0 ),
188+ "result_attr" : "num_qubits" ,
189+ "seed" : 1 ,
190+ "shots" : 1 ,
191+ "foreign_object" : None ,
192+ }
193+ worker_args = [
194+ (queue , sim_run_from_bytes , {** kwargs , "seed" : 1 }, 0 ),
195+ (queue , sim_run_from_bytes , {** kwargs , "seed" : 2 }, 1 ),
196+ ]
197+
198+ with ctx .Pool (processes = 2 ) as pool :
199+ presults = pool .map_async (worker_wrapper , worker_args ).get (
200+ timeout = _POOL_TIMEOUT ,
201+ )
202+
203+ for result_dict , run_info in presults :
204+ assert result_dict == {"measurements" : [4 ]}
205+ assert "pid" in run_info
206+
207+ def test_queue_message_passing (self ) -> None :
208+ """Test that stdout/stderr from workers is captured on the queue."""
209+ sim = StateVec (2 , seed = 0 )
210+ sim_bytes = pickle .dumps (sim )
211+
212+ ctx = _get_pool_context ()
213+ manager = ctx .Manager ()
214+ queue = manager .Queue ()
215+
216+ kwargs = {
217+ "sim_bytes" : sim_bytes ,
218+ "method" : "run_1q_gate" ,
219+ "method_args" : ("H" , 0 ),
220+ "result_attr" : "num_qubits" ,
221+ "seed" : 1 ,
222+ "shots" : 1 ,
223+ "foreign_object" : None ,
224+ }
225+ worker_args = [(queue , sim_run_from_bytes , kwargs , 0 )]
226+
227+ with ctx .Pool (processes = 1 ) as pool :
228+ pool .map_async (worker_wrapper , worker_args ).get (timeout = _POOL_TIMEOUT )
229+
230+ # The queue may contain stdout/stderr messages captured by WriteStream.
231+ # We just verify the queue is accessible and drainable (no deadlock).
232+ messages = []
233+ while not queue .empty ():
234+ messages .append (queue .get ())
235+ # Messages are (pid, stream_type, data) tuples if any output occurred.
236+ for msg in messages :
237+ assert len (msg ) == 3
238+
239+
240+ # ---------------------------------------------------------------------------
241+ # Callable-with-kwargs pattern tests (run_callable_worker)
242+ #
243+ # This tests the simpler pattern where a callable + kwargs dict are passed
244+ # to the pool, similar to how run_multisim passes eng.run + kwargs to workers.
245+ # ---------------------------------------------------------------------------
246+
247+
248+ @pytest .mark .timeout (120 )
249+ class TestRunCallableWorker :
250+ """Tests for the callable+kwargs worker pattern used in production."""
251+
252+ def test_callable_worker_statevec (self ) -> None :
253+ """Test passing a callable + kwargs through the pool."""
254+ sim = StateVec (3 , seed = 42 )
255+ sim .run_1q_gate ("H" , 0 )
256+ sim_bytes = pickle .dumps (sim )
257+
258+ kwargs = {
259+ "sim_bytes" : sim_bytes ,
260+ "method" : "run_1q_gate" ,
261+ "method_args" : ("H" , 0 ),
262+ "result_attr" : "num_qubits" ,
263+ }
264+ args = [(sim_run_from_bytes , kwargs )] * 2
265+
266+ ctx = _get_pool_context ()
267+ with ctx .Pool (processes = 2 ) as pool :
268+ results = pool .map_async (run_callable_worker , args ).get (
269+ timeout = _POOL_TIMEOUT ,
270+ )
271+
272+ assert results == [{"measurements" : [3 ]}, {"measurements" : [3 ]}]
0 commit comments