4444import logging
4545import multiprocessing
4646import platform
47+ import sys
4748import warnings
4849from functools import partial
4950from typing import Any
@@ -150,15 +151,77 @@ def get_config_for_mode(self, mode: DatasetMode) -> "MultiDatasetConfig":
150151 return self .get_subset (datasets_stage_mask )
151152
152153
154+
class DataModuleConfig(BaseModel):
    """Settings read by ``DataModule`` when it builds its dataloaders."""

    datasets: list[SerializeAsAny[BaseModel]]
    batch_size: int = 1
    num_workers: int = 0
    num_workers_validation: int = 0
    # "openfold-default" is a sentinel resolved per platform by
    # safe_multiprocessing_context(); None is also accepted and means
    # "leave the choice of start method to the platform".
    multiprocessing_context: str | None = "openfold-default"
    data_seed: int = 42
    epoch_len: int = 1

    @staticmethod
    def safe_multiprocessing_context(
        multiprocessing_context: str | None, num_workers: int
    ) -> str | None:
        """
        Resolve the multiprocessing start method with safer/sensible defaults.

        The sentinel value "openfold-default" is resolved as follows:
            - None when no worker processes are requested
            - "fork" when using MPS
            - "forkserver" on Linux, matching the new Python 3.14 default
            - the platform default otherwise

        Any other value is returned unchanged; a warning is logged when that
        value is known to be risky on the current platform.

        For general info on risks and defaults across platforms and python
        versions see:
        https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
        https://docs.pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-poison-fork-note
        https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods

        Args:
            multiprocessing_context: Requested start method, the sentinel
                "openfold-default", or None.
            num_workers: Number of worker processes that will be started.

        Returns:
            The start method name to use, or None.
        """
        # Do not bother if not using multiprocessing.
        if num_workers <= 0:
            # The sentinel is not a real start method, so it must never be
            # handed on as-is. Explicit user values are passed through.
            if multiprocessing_context == "openfold-default":
                return None
            return multiprocessing_context

        system = platform.system()
        on_mps = system == "Darwin" and torch.backends.mps.is_available()
        # platform.system() reports "Linux" (capitalised), not "linux".
        on_linux = system == "Linux"

        # Set safe defaults
        if multiprocessing_context == "openfold-default":
            # Use fork to create processes when using MPS. See:
            # - https://github.com/pytorch/pytorch/issues/70344
            # - https://github.com/pytorch/pytorch/issues/87688
            if on_mps:
                return "fork"

            # Use forkserver in linux.
            # Backports the new python 3.14 default in previous python versions.
            # An alternative for further safety would be "spawn". Avoid "fork".
            # See: https://github.com/python/cpython/issues/84559
            if on_linux:
                return "forkserver"

            # Use the platform default otherwise - "spawn" at the time of writing
            return multiprocessing.get_start_method()

        # Warn about unsafe explicit choices
        if on_mps and multiprocessing_context != "fork":
            logger.warning(
                f"Using multiprocessing context {multiprocessing_context} on MPS may cause "
                "issues. Consider using 'fork' or 'openfold-default' (which resolves to 'fork' on MPS).",
                stacklevel=2,
            )

        if on_linux:
            # None defers to the interpreter default, which is "fork" on
            # Linux before Python 3.14.
            dangerous_start_method = multiprocessing_context == "fork" or (
                multiprocessing_context is None and sys.version_info < (3, 14)
            )
            if dangerous_start_method:
                logger.warning(
                    "Using 'fork' multiprocessing context in linux may cause issues. Consider using "
                    "'spawn', 'forkserver' or 'openfold-default' (which resolves to 'forkserver' on linux).",
                    stacklevel=2,
                )

        return multiprocessing_context
224+
162225
163226class DataModule (pl .LightningDataModule ):
164227 """A LightningDataModule class for organizing Datasets and DataLoaders."""
@@ -170,7 +233,7 @@ def __init__(self, data_module_config: DataModuleConfig) -> None:
170233 self .batch_size = data_module_config .batch_size
171234 self .num_workers = data_module_config .num_workers
172235 self .num_workers_validation = data_module_config .num_workers_validation
173- self .multiprocessing_context = data_module_config .multiprocessing_context
236+ self .multiprocessing_context = data_module_config .safe_multiprocessing_context
174237 self .data_seed = data_module_config .data_seed
175238 self .next_data_seed = data_module_config .data_seed
176239 self .epoch_len = data_module_config .epoch_len
@@ -438,22 +501,11 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
438501 # passed explicitly here.
439502 worker_init_fn = partial (pl_worker_init_function , rank = self .global_rank )
440503
441- # Configure multiprocessing_context with sensible defaults
442- # For general info on risks see:
443- # https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
444- # https://docs.pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-poison-fork-note
445- multiprocessing_context = self .multiprocessing_context
446- # Use known/safer working defaults
447- if multiprocessing_context is None and num_workers > 0 :
448- # Use fork to create processes when using MPS
449- # See:
450- # - https://github.com/pytorch/pytorch/issues/70344
451- # - https://github.com/pytorch/pytorch/issues/87688
452- if platform .system () == "Darwin" and torch .backends .mps .is_available ():
453- multiprocessing_context = "fork"
454- # Use spawn by default in aarch64 as it is the safer bet (we observed failures with default)
455- elif platform .system () == "linux" and platform .machine () == "aarch64" :
456- multiprocessing_context = "spawn"
504+ # Set a sensible default for multiprocessing start method
505+ # depending on platform and python version.
506+ multiprocessing_context = DataModuleConfig .safe_multiprocessing_context (
507+ self .multiprocessing_context , num_workers
508+ )
457509
458510 logger .debug (
459511 f"Creating { mode } dataloader: "
0 commit comments