Skip to content

Commit 98106c4

Browse files
committed
add example
1 parent 1953142 commit 98106c4

1 file changed

Lines changed: 35 additions & 9 deletions

File tree

examples/example_10_mlp_relax.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,23 @@ def guess_ref_code_from_cif(path):
157157
return ref_code
158158

159159

160+
def _configure_worker_threads():
161+
"""Avoid CPU oversubscription when running multiple PyTorch workers."""
162+
os.environ.setdefault("OMP_NUM_THREADS", "1")
163+
os.environ.setdefault("MKL_NUM_THREADS", "1")
164+
try:
165+
import torch
166+
torch.set_num_threads(1)
167+
except ImportError:
168+
pass
169+
170+
171+
def _init_worker(calculator, model, quick):
172+
"""Load the calculator once per worker process (safe with spawn)."""
173+
_configure_worker_threads()
174+
get_calculator(calculator, model=model, quick=quick)
175+
176+
160177
def worker_relax(args):
161178
"""Worker to relax a single CIF block.
162179
args: (label, block_text, orig_energy, step, fmax, label_idx, calculator, model, quick)
@@ -331,13 +348,19 @@ def main(cif_path, nproc=4, step=200, fmax=0.1, out_dir=None, db_file=None, ref_
331348
chunksize = max(1, len(tasks) // nproc)
332349
print(f'Running {len(tasks)} relaxation tasks for {len(unique_selected)} unique structures with calculator={calculator}, model={model}')
333350
print(f'Using chunksize={chunksize} for {nproc} workers')
334-
print('Preloading calculator...')
335-
get_calculator(calculator, model=model, quick=quick)
336-
# fork shares the preloaded model across workers on macOS/Linux.
337-
ctx = mp.get_context('fork')
338-
pool = ctx.Pool(processes=nproc)
339-
results = pool.map(worker_relax, tasks, chunksize=chunksize)
340-
pool.close(); pool.join()
351+
if nproc == 1:
352+
_configure_worker_threads()
353+
get_calculator(calculator, model=model, quick=quick)
354+
results = [worker_relax(task) for task in tasks]
355+
else:
356+
# spawn avoids macOS fork crashes after MPS/PyTorch init in the parent.
357+
ctx = mp.get_context("spawn")
358+
with ctx.Pool(
359+
processes=nproc,
360+
initializer=_init_worker,
361+
initargs=(calculator, model, quick),
362+
) as pool:
363+
results = pool.map(worker_relax, tasks, chunksize=chunksize)
341364
print(f'Completed {len(results)} relaxation results')
342365
total_relax_time = sum([r[6] for r in results if len(r) > 6 and r[6] is not None])
343366
print(f'Total relaxation wall time: {total_relax_time:.1f} s')
@@ -486,8 +509,11 @@ def main(cif_path, nproc=4, step=200, fmax=0.1, out_dir=None, db_file=None, ref_
486509

487510

488511
if __name__ == '__main__':
489-
# On macOS, use 'fork' method to preserve conda environment in worker processes
490-
mp.set_start_method('fork', force=True)
512+
# spawn is required on macOS when PyTorch/MPS calculators are used with multiprocessing.
513+
try:
514+
mp.set_start_method('spawn', force=True)
515+
except RuntimeError:
516+
pass
491517

492518
parser = argparse.ArgumentParser(description='Relax QRS CIF blocks with MACE and compare energies')
493519
parser.add_argument('cif', help='Path to QRS multi-block CIF or directory containing QRS-openffall.cif')

0 commit comments

Comments
 (0)