Skip to content

Commit ac7610a

Browse files
committed
replace the closure in wrap_with_catch with a callable class
1 parent 1820d03 commit ac7610a

1 file changed

Lines changed: 22 additions & 11 deletions

File tree

src/gradient_free_optimizers/_catch.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,23 @@
66
logger = logging.getLogger(__name__)
77

88

9-
def wrap_with_catch(
10-
objective_function: Callable,
11-
catch: dict[type[Exception], int | float],
12-
) -> Callable:
13-
"""Wrap objective function to catch exceptions and return fallback scores."""
14-
catch_types = tuple(catch.keys())
9+
class _CatchWrapper:
10+
"""Callable that catches exceptions from the objective and returns fallbacks.
11+
12+
A class instead of a closure so it can be pickled by multiprocessing's
13+
spawn context (Windows, macOS 3.14+).
14+
"""
15+
16+
def __init__(self, func: Callable, catch: dict[type[Exception], int | float]):
17+
self.func = func
18+
self.catch = catch
19+
self._catch_types = tuple(catch.keys())
1520

16-
def wrapped(params):
21+
def __call__(self, params):
1722
try:
18-
return objective_function(params)
19-
except catch_types as e:
20-
for exc_type, fallback_score in catch.items():
23+
return self.func(params)
24+
except self._catch_types as e:
25+
for exc_type, fallback_score in self.catch.items():
2126
if isinstance(e, exc_type):
2227
logger.warning(
2328
"Caught %s in objective function: %s. "
@@ -29,4 +34,10 @@ def wrapped(params):
2934
return fallback_score
3035
raise
3136

32-
return wrapped
37+
38+
def wrap_with_catch(
39+
objective_function: Callable,
40+
catch: dict[type[Exception], int | float],
41+
) -> Callable:
42+
"""Wrap objective function to catch exceptions and return fallback scores."""
43+
return _CatchWrapper(objective_function, catch)

0 commit comments

Comments
 (0)