11"""Multiprocessing backend for distributed evaluation.
22
3- Uses a module-level variable to pass the objective function to worker
4- processes, avoiding pickling issues. Only the parameter dicts (plain
5- dicts of floats/strings) are serialized over the pool's pipe.
3+ Where the ``fork`` start method is available (POSIX: Linux, macOS), a
4+ module-level variable hands the objective function to the workers, which
5+ inherit it at fork time, so it is never pickled. Where ``fork`` is
6+ unavailable (Windows), the function is pickled alongside each parameter
7+ dict via ``starmap``, which requires it to be importable at module level.
68"""
79
810from __future__ import annotations
1214from ._base import BaseDistribution
1315
1416# Shared between parent and forked workers via copy-on-write memory.
15- # Set by _distribute() before pool.map, cleared after .
17+ # Only used with the fork context .
1618_worker_func = None
1719
1820
def _eval_single(params):
    """Worker entry point for the fork context.

    Evaluates the objective stored in the module-level ``_worker_func``
    on a single parameter dict. Only ``params`` is passed in as an
    argument; the objective itself is read from module state.

    Parameters
    ----------
    params : dict
        Parameter dict forwarded unchanged to the objective function.

    Returns
    -------
    object
        Whatever the objective function returns for ``params``.

    Raises
    ------
    RuntimeError
        If ``_worker_func`` has not been set in this process.
    """
    if _worker_func is None:
        # Fail with an explicit message instead of the opaque
        # "'NoneType' object is not callable".
        raise RuntimeError(
            "_eval_single() called without an objective function; "
            "_worker_func must be set before the worker pool is created"
        )
    return _worker_func(params)
2224
2325
26+ def _eval_with_func (func , params ):
27+ """Worker entry point for spawn context (func is pickled per call)."""
28+ return func (params )
29+
30+
2431class Multiprocessing (BaseDistribution ):
2532 """Distribute evaluations across local processes via multiprocessing.Pool.
2633
@@ -44,9 +51,10 @@ def objective(para):
4451
4552 Notes
4653 -----
47- Uses the ``fork`` multiprocessing context on Linux/macOS so that
48- the objective function does not need to be picklable. On systems
49- where ``fork`` is unavailable the function must be defined at module
54+ Prefers the ``fork`` context (available on POSIX: Linux, macOS) so
55+ the objective function is inherited by workers without pickling.
56+ Where ``fork`` is unavailable (Windows), falls back to the platform
57+ default, ``spawn``; the objective must then be defined at module
5058 level (not a lambda or closure).
5159 """
5260
@@ -56,20 +64,29 @@ def __init__(self, n_workers: int = -1):
5664
5765 n_workers = os .cpu_count () or 1
5866 super ().__init__ (n_workers )
67+ self ._mp_context = self ._select_context ()
68+ self ._use_fork = self ._mp_context .get_start_method () == "fork"
5969
60- def _distribute (self , func , params_batch ):
61- """Evaluate objective in parallel using multiprocessing.Pool.
70+ @staticmethod
71+ def _select_context ():
72+ available = multiprocessing .get_all_start_methods ()
73+ if "fork" in available :
74+ return multiprocessing .get_context ("fork" )
75+ return multiprocessing .get_context ()
6276
63- The function is stored in a module-level variable and inherited
64- by forked workers, so only the params dicts travel through the
65- serialization pipe.
66- """
def _distribute(self, func, params_batch):
    """Evaluate ``func`` on every parameter dict using a process pool.

    With the fork context the objective is stored in the module-level
    ``_worker_func`` before the pool is created, so forked workers
    inherit it and only the parameter dicts are sent to them. With any
    other context the objective is sent along with each parameter dict
    through ``starmap`` and must therefore be picklable.

    Parameters
    ----------
    func : callable
        Objective function; called as ``func(params)`` in a worker.
    params_batch : iterable of dict
        Parameter dicts to evaluate.

    Returns
    -------
    list
        One result per entry of ``params_batch``, in input order.
    """
    global _worker_func

    if not self._use_fork:
        with self._mp_context.Pool(self.n_workers) as pool:
            return pool.starmap(
                _eval_with_func,
                [(func, p) for p in params_batch],
            )

    # Set before Pool creation so forked workers inherit it.
    _worker_func = func
    try:
        with self._mp_context.Pool(self.n_workers) as pool:
            return pool.map(_eval_single, params_batch)
    finally:
        # Always reset, even when evaluation raises, so the parent never
        # keeps a stale objective (and whatever it references) alive.
        _worker_func = None
0 commit comments