@@ -157,6 +157,23 @@ def guess_ref_code_from_cif(path):
157157 return ref_code
158158
159159
def _configure_worker_threads():
    """Pin this process to a single math-library thread.

    Intended to keep several PyTorch worker processes from oversubscribing
    the CPU. The ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` environment
    variables are set to ``"1"`` only when they are not already defined, so
    a value exported by the caller is left alone. PyTorch's thread count is
    then set to 1 unconditionally; if PyTorch cannot be imported the step is
    skipped silently.
    """
    # setdefault() keeps any value that is already present in the environment.
    for thread_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
        os.environ.setdefault(thread_var, "1")

    # torch is optional here: a missing install is a deliberate no-op.
    try:
        import torch

        torch.set_num_threads(1)
    except ImportError:
        return
169+
170+
def _init_worker(calculator, model, quick):
    """Prepare a freshly started worker process before it receives tasks.

    Thread limits are applied first, then ``get_calculator`` is called with
    the given settings so the calculator is loaded once in this process.
    The returned calculator is discarded here (presumably ``get_calculator``
    caches it for later calls -- confirm against its definition).

    Args:
        calculator: Calculator identifier, passed positionally to
            ``get_calculator``.
        model: Forwarded as the ``model`` keyword argument.
        quick: Forwarded as the ``quick`` keyword argument.
    """
    # Limit threads before the calculator (and its backend) is loaded.
    _configure_worker_threads()

    loader_options = {"model": model, "quick": quick}
    get_calculator(calculator, **loader_options)
175+
176+
160177def worker_relax (args ):
161178 """Worker to relax a single CIF block.
162179 args: (label, block_text, orig_energy, step, fmax, label_idx, calculator, model, quick)
@@ -331,13 +348,19 @@ def main(cif_path, nproc=4, step=200, fmax=0.1, out_dir=None, db_file=None, ref_
331348 chunksize = max (1 , len (tasks ) // nproc )
332349 print (f'Running { len (tasks )} relaxation tasks for { len (unique_selected )} unique structures with calculator={ calculator } , model={ model } ' )
333350 print (f'Using chunksize={ chunksize } for { nproc } workers' )
334- print ('Preloading calculator...' )
335- get_calculator (calculator , model = model , quick = quick )
336- # fork shares the preloaded model across workers on macOS/Linux.
337- ctx = mp .get_context ('fork' )
338- pool = ctx .Pool (processes = nproc )
339- results = pool .map (worker_relax , tasks , chunksize = chunksize )
340- pool .close (); pool .join ()
351+ if nproc == 1 :
352+ _configure_worker_threads ()
353+ get_calculator (calculator , model = model , quick = quick )
354+ results = [worker_relax (task ) for task in tasks ]
355+ else :
356+ # spawn avoids macOS fork crashes after MPS/PyTorch init in the parent.
357+ ctx = mp .get_context ("spawn" )
358+ with ctx .Pool (
359+ processes = nproc ,
360+ initializer = _init_worker ,
361+ initargs = (calculator , model , quick ),
362+ ) as pool :
363+ results = pool .map (worker_relax , tasks , chunksize = chunksize )
341364 print (f'Completed { len (results )} relaxation results' )
342365 total_relax_time = sum ([r [6 ] for r in results if len (r ) > 6 and r [6 ] is not None ])
343366 print (f'Total relaxation wall time: { total_relax_time :.1f} s' )
@@ -486,8 +509,11 @@ def main(cif_path, nproc=4, step=200, fmax=0.1, out_dir=None, db_file=None, ref_
486509
487510
488511if __name__ == '__main__' :
489- # On macOS, use 'fork' method to preserve conda environment in worker processes
490- mp .set_start_method ('fork' , force = True )
512+ # spawn is required on macOS when PyTorch/MPS calculators are used with multiprocessing.
513+ try :
514+ mp .set_start_method ('spawn' , force = True )
515+ except RuntimeError :
516+ pass
491517
492518 parser = argparse .ArgumentParser (description = 'Relax QRS CIF blocks with MACE and compare energies' )
493519 parser .add_argument ('cif' , help = 'Path to QRS multi-block CIF or directory containing QRS-openffall.cif' )
0 commit comments