Skip to content

Commit 1fb8bf2

Browse files
committed
Better defaults and logging for multiprocessing context
We would still need to document this somewhere
1 parent 2927924 commit 1fb8bf2

1 file changed

Lines changed: 70 additions & 18 deletions

File tree

openfold3/core/data/framework/data_module.py

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import logging
4545
import multiprocessing
4646
import platform
47+
import sys
4748
import warnings
4849
from functools import partial
4950
from typing import Any
@@ -150,15 +151,77 @@ def get_config_for_mode(self, mode: DatasetMode) -> "MultiDatasetConfig":
150151
return self.get_subset(datasets_stage_mask)
151152

152153

154+
153155
class DataModuleConfig(BaseModel):
154156
datasets: list[SerializeAsAny[BaseModel]]
155157
batch_size: int = 1
156158
num_workers: int = 0
157159
num_workers_validation: int = 0
158-
multiprocessing_context: str = None
160+
multiprocessing_context: str = "openfold-default"
159161
data_seed: int = 42
160162
epoch_len: int = 1
161163

164+
@staticmethod
165+
def safe_multiprocessing_context(
166+
multiprocessing_context: str | None, num_workers: int
167+
) -> str | None:
168+
"""
169+
Returns multiprocessing start methods with safer/sensible defaults:
170+
- fork when using MPS
171+
- forkserver for linux, matching the new 3.14 default
172+
- default otherwise
173+
174+
For general info on risks and defaults across platformas and python versions see:
175+
https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
176+
https://docs.pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-poison-fork-note
177+
https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
178+
"""
179+
180+
# Do not bother if not using multiprocessing
181+
if num_workers > 0:
182+
183+
# Set safe defaults
184+
if multiprocessing_context == "openfold-default":
185+
186+
# Use fork to create processes when using MPS. See:
187+
# - https://github.com/pytorch/pytorch/issues/70344
188+
# - https://github.com/pytorch/pytorch/issues/87688
189+
if platform.system() == "Darwin" and torch.backends.mps.is_available():
190+
return "fork"
191+
192+
# Use forkserver in linux
193+
# Backports the new python 3.14 default in previous python versions.
194+
# An alternative for further safety would be "spawn". Avoid "fork".
195+
# See: https://github.com/python/cpython/issues/84559
196+
if platform.system() == "linux":
197+
return "forkserver"
198+
199+
# Use the platform default otherwise - "spawn" at the time of writing
200+
return multiprocessing.get_start_method()
201+
202+
# Warn about unsafe defaults
203+
else:
204+
if platform.system() == "Darwin" and torch.backends.mps.is_available():
205+
if multiprocessing_context != "fork":
206+
logger.warning(
207+
f"Using multiprocessing context {multiprocessing_context} on MPS may cause "
208+
"issues. Consider using 'fork' or 'openfold-default' (which resolves to 'fork' on MPS).",
209+
stacklevel=2,
210+
)
211+
if platform.system() == "linux":
212+
dangerous_start_method = (
213+
multiprocessing_context == "fork" or
214+
multiprocessing_context is None and sys.version_info < (3, 14)
215+
)
216+
if dangerous_start_method:
217+
logger.warning(
218+
"Using 'fork' multiprocessing context in linux may cause issues. Consider using "
219+
"'spawn', 'forkserver' or 'openfold-default' (which resolves to 'forkserver' on linux).",
220+
stacklevel=2,
221+
)
222+
223+
return multiprocessing_context
224+
162225

163226
class DataModule(pl.LightningDataModule):
164227
"""A LightningDataModule class for organizing Datasets and DataLoaders."""
@@ -170,7 +233,7 @@ def __init__(self, data_module_config: DataModuleConfig) -> None:
170233
self.batch_size = data_module_config.batch_size
171234
self.num_workers = data_module_config.num_workers
172235
self.num_workers_validation = data_module_config.num_workers_validation
173-
self.multiprocessing_context = data_module_config.multiprocessing_context
236+
self.multiprocessing_context = data_module_config.safe_multiprocessing_context
174237
self.data_seed = data_module_config.data_seed
175238
self.next_data_seed = data_module_config.data_seed
176239
self.epoch_len = data_module_config.epoch_len
@@ -438,22 +501,11 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
438501
# passed explicitly here.
439502
worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
440503

441-
# Configure multiprocessing_context with sensible defaults
442-
# For general info on risks see:
443-
# https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
444-
# https://docs.pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-poison-fork-note
445-
multiprocessing_context = self.multiprocessing_context
446-
# Use known/safer working defaults
447-
if multiprocessing_context is None and num_workers > 0:
448-
# Use fork to create processes when using MPS
449-
# See:
450-
# - https://github.com/pytorch/pytorch/issues/70344
451-
# - https://github.com/pytorch/pytorch/issues/87688
452-
if platform.system() == "Darwin" and torch.backends.mps.is_available():
453-
multiprocessing_context = "fork"
454-
# Use spawn by default in aarch64 as it is the safer bet (we observed failures with default)
455-
elif platform.system() == "linux" and platform.machine() == "aarch64":
456-
multiprocessing_context = "spawn"
504+
# Set a sensible default for multiprocesssing start method
505+
# depending on platform and python version.
506+
multiprocessing_context = DataModuleConfig.safe_multiprocessing_context(
507+
self.multiprocessing_context, num_workers
508+
)
457509

458510
logger.debug(
459511
f"Creating {mode} dataloader: "

0 commit comments

Comments
 (0)