Skip to content

Commit f91f5e8

Browse files
committed
Merge branch 'master' of https://github.com/qzhu2017/PyXtal
2 parents 0ba7eb7 + 77eacad commit f91f5e8

4 files changed

Lines changed: 60 additions & 45 deletions

File tree

examples/example_09_QRS_conf.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def get_visited_energies(qrs):
9191
return energies
9292

9393

94-
def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_dir="qrs_plots", time_cost_s=None, coverage=None, n_conformers=None):
94+
def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_dir="qrs_plots", time_cost_s=None, coverage=None, n_conformers=None, energy_unit="kcal/mol"):
9595
"""Save a plot of visited-structure ID vs energy for one QRS run."""
9696
if not energies:
9797
print(f"No energies collected for {code}; skipping plot.")
@@ -101,7 +101,7 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
101101
ids = list(range(1, len(energies) + 1))
102102

103103
fig, ax = plt.subplots(figsize=(7, 4.5))
104-
ax.scatter(ids, energies, s=20, alpha=0.8, label="Visited")
104+
ax.scatter(ids, energies, s=10, alpha=0.8, label="Visited")
105105
ax.plot(ids, energies, linewidth=0.8, alpha=0.6)
106106
if match_ids and match_energies:
107107
ax.scatter(
@@ -116,12 +116,12 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
116116
label="Match",
117117
)
118118
ax.set_xlabel("Visited Structure ID")
119-
ax.set_ylabel("Energy (kcal/mol)")
119+
ax.set_ylabel(f"Energy ({energy_unit})")
120120
title_parts = [f"{code}: "]
121121
if n_conformers is not None:
122122
title_parts.append(f"conf: {n_conformers}")
123123
if time_cost_s is not None:
124-
title_parts.append(f"time: {time_cost_s:.1f} s")
124+
title_parts.append(f"time: {time_cost_s:.2f} s")
125125
if coverage is not None:
126126
title_parts.append(f"coverage: {coverage}")
127127
if len(title_parts) > 1:
@@ -138,7 +138,7 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
138138
ax.set_ylim(ymin - margin, ymax + margin)
139139
else:
140140
y_max = min(ymin + 30, ymax)
141-
ax.set_ylim(ymin - 1, y_max)
141+
ax.set_ylim(ymin - 0.25, y_max)
142142
fig.tight_layout()
143143

144144
fig_path = os.path.join(out_dir, f"{code}.png")
@@ -173,8 +173,7 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
173173
for code in db.get_all_codes():
174174
#if code not in ['ACSALA']: continue
175175
#if code not in ['FUNZOE']: continue
176-
if code in ['ACEMID02']: continue
177-
#if code not in ['XAFPAY', 'OBEQIX', 'UJIRIO02']: continue
176+
if code not in ['XAFQON']: continue
178177
row = db.get_row(code=code)
179178
ref_xtal = db.get_pyxtal(code=code)
180179
if ref_xtal.has_special_site():
@@ -288,11 +287,12 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
288287
composition = [int(a) for a in ref_xtal.get_zprime()],
289288
molecules=molecules,
290289
sites=sites,
291-
N_gen=200,
292-
N_pop=96,
293-
N_cpu=48,
290+
N_gen=100,
291+
N_pop=48,
292+
N_cpu=2,#4,
294293
cif="all.cif",
295-
skip_mlp=True,
294+
skip_mlp=False,#True,
295+
mlp='MACEOFF',
296296
verbose=False,
297297
delta_length=1.0,
298298
delta_angle=15.0,
@@ -322,6 +322,7 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
322322
time_cost_s=time_cost_s,
323323
coverage=coverage,
324324
n_conformers=n_pregen_total,
325+
energy_unit="eV/atom" if not qrs.skip_mlp else "kcal/mol",
325326
)
326327

327328
with open(csv_path, "a", newline="") as fcsv:

pyxtal/interface/ase_opt.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
# --- Optional fix for PyTorch weights_only change (harmless if unused) ---
2-
from torch.serialization import add_safe_globals
3-
add_safe_globals([slice])
4-
# ------------------------------------------------------------------------
1+
# --- Temporary fix for PyTorch 2.6 weights_only change ---
2+
import torch as _torch
3+
import functools as _functools
4+
_original_torch_load = _torch.load
5+
@_functools.wraps(_original_torch_load)
6+
def _patched_torch_load(*args, **kwargs):
7+
kwargs.setdefault('weights_only', False)
8+
return _original_torch_load(*args, **kwargs)
9+
_torch.load = _patched_torch_load
10+
# ----------------------------------------------------------
511

612
import signal
713
import numpy as np
14+
import os
15+
import contextlib
816
from ase.constraints import FixSymmetry
917
from ase.filters import UnitCellFilter
1018
from ase.optimize.fire import FIRE
@@ -15,6 +23,13 @@
1523
_cached_uma = None
1624
_cached_ani = None
1725

26+
@contextlib.contextmanager
27+
def _suppress_stdout_stderr():
28+
"""Redirect stdout and stderr to /dev/null."""
29+
with open(os.devnull, 'w') as devnull:
30+
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
31+
yield
32+
1833
def get_calculator(calculator):
1934
"""
2035
Return an ASE calculator instance.
@@ -28,28 +43,32 @@ def get_calculator(calculator):
2843
if isinstance(calculator, str):
2944
if calculator == "UMA":
3045
if _cached_uma is None:
31-
from fairchem.core import pretrained_mlip, FAIRChemCalculator
32-
predictor = pretrained_mlip.get_predict_unit("uma-s-1p1")
33-
_cached_uma = FAIRChemCalculator(predictor,
34-
task_name="omc")
46+
with _suppress_stdout_stderr():
47+
from fairchem.core import pretrained_mlip, FAIRChemCalculator
48+
predictor = pretrained_mlip.get_predict_unit("uma-s-1p1")
49+
_cached_uma = FAIRChemCalculator(predictor,
50+
task_name="omc")
3551
calc = _cached_uma
3652

3753
elif calculator == "ANI":
3854
if _cached_ani is None:
39-
import torchani
40-
_cached_ani = torchani.models.ANI2x().ase()
55+
with _suppress_stdout_stderr():
56+
import torchani
57+
_cached_ani = torchani.models.ANI2x().ase()
4158
calc = _cached_ani
4259

4360
elif calculator == "MACE":
4461
if _cached_mace is None:
45-
from mace.calculators import mace_mp
46-
_cached_mace = mace_mp(model="small", dispersion=True)
62+
with _suppress_stdout_stderr():
63+
from mace.calculators import mace_mp
64+
_cached_mace = mace_mp(model="small", dispersion=True)
4765
calc = _cached_mace
4866

4967
elif calculator == "MACEOFF":
5068
if _cached_mace is None:
51-
from mace.calculators import mace_off
52-
_cached_mace = mace_off(model="medium")#, device="cpu")
69+
with _suppress_stdout_stderr():
70+
from mace.calculators import mace_off
71+
_cached_mace = mace_off(model="medium")#, device="cpu")
5372
calc = _cached_mace
5473

5574
else:

pyxtal/optimize/base.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -991,38 +991,33 @@ def local_optimization_mproc(self, xtals, ncpu, ids=None, qrs=False, pool=None):
991991
if ids is None:
992992
ids = range(len(xtals))
993993

994-
N_cycle = int(np.ceil(len(xtals) / ncpu))
995-
# Generator to create arg_lists for multiprocessing tasks
994+
# Interleaved assignment: worker i handles indices i, i+ncpu, i+2*ncpu, …
995+
# so all workers stay busy and results can be sorted to 0,1,2,3… order.
996996
def generate_args_lists():
997997
for i in range(ncpu):
998-
id1 = i * N_cycle
999-
id2 = min([id1 + N_cycle, len(xtals)])
1000-
_ids = ids[id1: id2]
998+
_indices = list(range(i, len(xtals), ncpu))
999+
_ids = [ids[j] for j in _indices]
10011000
job_tags = [self.tag + "-g" + str(gen)
10021001
+ "-p" + str(id) for id in _ids]
1003-
_xtals = [xtals[id][0] for id in range(id1, id2)]
1002+
_xtals = [xtals[j][0] for j in _indices]
10041003
mutates = []
10051004
labels = []
1006-
for i in range(id1, id2):
1007-
orig_tag = xtals[i][1]
1005+
for j in _indices:
1006+
orig_tag = xtals[j][1]
10081007
if qrs:
10091008
mutates.append(False)
10101009
labels.append(orig_tag if orig_tag != "Random" else None)
10111010
else:
1012-
if orig_tag == "Mutation":
1013-
mutates.append(True)
1014-
else:
1015-
mutates.append(False)
1011+
mutates.append(orig_tag == "Mutation")
10161012
labels.append(None)
10171013
my_args = [_xtals, _ids, mutates, job_tags, labels, *args, self.rank, self.timeout]
1018-
yield tuple(my_args) # Yield args instead of appending to a list
1014+
yield tuple(my_args)
10191015

10201016
gen_results = []
10211017
for result in pool.imap_unordered(process_task, generate_args_lists()):
10221018
if result is not None:
10231019
for _res in result:
10241020
gen_results.append(_res)
1025-
10261021
return gen_results
10271022

10281023
def gen_summary(self, t0, gen_results, xtals):

pyxtal/optimize/common.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,8 @@ def optimizer(
424424
and 0.25 < struc.get_density() < 3.0
425425
):
426426
s = struc.to_ase()
427-
step = 50 if mlp in ['MACE', 'ANI'] else 10
428-
s = ASE_relax(s, mlp, step=step, fmax=0.1, logfile="ase.log")
427+
step = 50 if mlp in ['MACE', 'ANI'] else 25
428+
s = ASE_relax(s, mlp, opt_lat=opt_lat, step=step, fmax=0.1, logfile="ase.log")
429429
if s is None: return None
430430
eng = s.get_potential_energy()
431431
stress = max(abs(s.get_stress())) / units.GPa
@@ -662,7 +662,7 @@ def optimizer_single(
662662
strs += f" {match:.3f}"
663663

664664
xtal.energy = eng
665-
print(f"{id:3d} " + strs)#; import sys; sys.exit()
665+
print(f"{id:3d} " + strs)
666666
return xtal, match, stable
667667
else:
668668
return None, match, stable
@@ -677,9 +677,9 @@ def refine_struc(xtal, smiles, calculator, mlp):
677677
calculator: ANI_relax or MACE_relax
678678
"""
679679
s = xtal.to_ase()
680-
s = calculator(s, mlp, step=50, fmax=0.1, logfile="ase.log")
681-
s = calculator(s, mlp, step=250, opt_cell=True, logfile="ase.log")
682-
s = calculator(s, mlp, step=50, fmax=0.1, logfile="ase.log")
680+
s = calculator(s, mlp, opt_lat=False, step=50, fmax=0.1, logfile="ase.log")
681+
s = calculator(s, mlp, opt_lat=True, step=250, logfile="ase.log")
682+
s = calculator(s, mlp, opt_lat=False, step=50, fmax=0.1, logfile="ase.log")
683683
eng1 = s.get_potential_energy() # /sum(xtal.numMols)
684684

685685
xtal = pyxtal(molecular=True)

0 commit comments

Comments
 (0)