Skip to content

Commit 1820d03

Browse files
committed
attempt to fix windows specific error
1 parent 63942ff commit 1820d03

File tree

1 file changed

+38
-21
lines changed

1 file changed

+38
-21
lines changed
Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Multiprocessing backend for distributed evaluation.
22
3-
Uses a module-level variable to pass the objective function to worker
4-
processes, avoiding pickling issues. Only the parameter dicts (plain
5-
dicts of floats/strings) are serialized over the pool's pipe.
3+
On platforms with ``fork`` (Linux, older macOS), uses a module-level
4+
variable to pass the objective function to workers via shared memory,
5+
avoiding pickling entirely. On ``spawn`` platforms (Windows, macOS
6+
3.14+), the function is pickled alongside each parameter dict using
7+
``starmap``, which requires it to be importable at module level.
68
"""
79

810
from __future__ import annotations
@@ -12,15 +14,20 @@
1214
from ._base import BaseDistribution
1315

1416
# Shared between parent and forked workers via copy-on-write memory.
15-
# Set by _distribute() before pool.map, cleared after.
17+
# Only used with the fork context.
1618
_worker_func = None
1719

1820

1921
def _eval_single(params):
20-
"""Evaluate a single parameter dict in a worker process."""
22+
"""Worker entry point for fork context."""
2123
return _worker_func(params)
2224

2325

26+
def _eval_with_func(func, params):
27+
"""Worker entry point for spawn context (func is pickled per call)."""
28+
return func(params)
29+
30+
2431
class Multiprocessing(BaseDistribution):
2532
"""Distribute evaluations across local processes via multiprocessing.Pool.
2633
@@ -44,9 +51,10 @@ def objective(para):
4451
4552
Notes
4653
-----
47-
Uses the ``fork`` multiprocessing context on Linux/macOS so that
48-
the objective function does not need to be picklable. On systems
49-
where ``fork`` is unavailable the function must be defined at module
54+
Prefers the ``fork`` context (Linux/macOS) so the objective function
55+
is inherited by workers without pickling. Falls back to the platform
56+
default (``spawn`` on Windows, macOS 3.14+) where ``fork`` is
57+
unavailable. With ``spawn``, the objective must be defined at module
5058
level (not a lambda or closure).
5159
"""
5260

@@ -56,20 +64,29 @@ def __init__(self, n_workers: int = -1):
5664

5765
n_workers = os.cpu_count() or 1
5866
super().__init__(n_workers)
67+
self._mp_context = self._select_context()
68+
self._use_fork = self._mp_context.get_start_method() == "fork"
5969

60-
def _distribute(self, func, params_batch):
61-
"""Evaluate objective in parallel using multiprocessing.Pool.
70+
@staticmethod
71+
def _select_context():
72+
available = multiprocessing.get_all_start_methods()
73+
if "fork" in available:
74+
return multiprocessing.get_context("fork")
75+
return multiprocessing.get_context()
6276

63-
The function is stored in a module-level variable and inherited
64-
by forked workers, so only the params dicts travel through the
65-
serialization pipe.
66-
"""
77+
def _distribute(self, func, params_batch):
6778
global _worker_func
68-
_worker_func = func
69-
try:
70-
ctx = multiprocessing.get_context("fork")
71-
with ctx.Pool(self.n_workers) as pool:
72-
scores = pool.map(_eval_single, params_batch)
73-
finally:
79+
80+
if self._use_fork:
81+
# Set before Pool creation so forked workers inherit it
82+
_worker_func = func
83+
with self._mp_context.Pool(self.n_workers) as pool:
84+
results = pool.map(_eval_single, params_batch)
7485
_worker_func = None
75-
return scores
86+
else:
87+
with self._mp_context.Pool(self.n_workers) as pool:
88+
results = pool.starmap(
89+
_eval_with_func,
90+
[(func, p) for p in params_batch],
91+
)
92+
return results

0 commit comments

Comments
 (0)