Skip to content

Commit 78041b9

Browse files
committed
dynamical grids according to size
1 parent a42a30f commit 78041b9

2 files changed

Lines changed: 116 additions & 18 deletions

File tree

examples/example_09_QRS_conf.py

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,68 @@ def get_visited_energies(qrs):
9191
return energies
9292

9393

94-
def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_dir="qrs_plots", time_cost_s=None, coverage=None, n_conformers=None):
94+
def select_delta_angle(molecules, composition=None):
95+
"""Return delta_angle(s) based on molecule sizes.
96+
97+
If `composition` is provided and has multiple components, return a list
98+
of delta angles (one per component). Otherwise return a single float.
99+
100+
Rules (per component):
101+
- max_atoms < 5 -> 60.0
102+
- 5 <= max_atoms < 10 -> 45.0
103+
- max_atoms < 30 -> 30.0
104+
- else -> 15.0
105+
"""
106+
# Default fallback
107+
DEFAULT = 15.0
108+
if not molecules:
109+
if composition is None or len(composition) <= 1:
110+
return DEFAULT
111+
return [DEFAULT for _ in composition]
112+
113+
# If composition provided and multi-component, compute per-component
114+
if composition is not None and len(composition) > 1:
115+
delta_list = []
116+
for idx, count in enumerate(composition):
117+
# guard when molecules list does not align with composition
118+
pool = molecules[idx] if idx < len(molecules) else None
119+
if not pool:
120+
delta_list.append(DEFAULT)
121+
continue
122+
mol = pool[0]
123+
try:
124+
atom_count = len(mol.mol)
125+
except Exception:
126+
atom_count = len(getattr(mol, "atoms", []))
127+
if atom_count <= 3:
128+
delta_list.append(90.0)
129+
elif atom_count < 5:
130+
delta_list.append(60.0)
131+
elif atom_count < 10:
132+
delta_list.append(45.0)
133+
elif atom_count < 30:
134+
delta_list.append(30.0)
135+
else:
136+
delta_list.append(15.0)
137+
return delta_list
138+
139+
# Single-component fallback
140+
pool = molecules[0] if isinstance(molecules, (list, tuple)) and molecules else None
141+
if not pool:
142+
return DEFAULT
143+
mol = pool[0]
144+
try:
145+
atom_count = len(mol.mol)
146+
except Exception:
147+
atom_count = len(getattr(mol, "atoms", []))
148+
if atom_count < 10:
149+
return 45.0
150+
if atom_count < 30:
151+
return 30.0
152+
return 15.0
153+
154+
155+
def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_dir="qrs_plots", time_cost_s=None, coverage=None, n_conformers=None, energy_unit="kcal/mol"):
95156
"""Save a plot of visited-structure ID vs energy for one QRS run."""
96157
if not energies:
97158
print(f"No energies collected for {code}; skipping plot.")
@@ -101,7 +162,7 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
101162
ids = list(range(1, len(energies) + 1))
102163

103164
fig, ax = plt.subplots(figsize=(7, 4.5))
104-
ax.scatter(ids, energies, s=20, alpha=0.8, label="Visited")
165+
ax.scatter(ids, energies, s=10, alpha=0.8, label="Visited")
105166
ax.plot(ids, energies, linewidth=0.8, alpha=0.6)
106167
if match_ids and match_energies:
107168
ax.scatter(
@@ -116,12 +177,12 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
116177
label="Match",
117178
)
118179
ax.set_xlabel("Visited Structure ID")
119-
ax.set_ylabel("Energy (kcal/mol)")
180+
ax.set_ylabel(f"Energy ({energy_unit})")
120181
title_parts = [f"{code}: "]
121182
if n_conformers is not None:
122183
title_parts.append(f"conf: {n_conformers}")
123184
if time_cost_s is not None:
124-
title_parts.append(f"time: {time_cost_s:.1f} s")
185+
title_parts.append(f"time: {time_cost_s:.2f} s")
125186
if coverage is not None:
126187
title_parts.append(f"coverage: {coverage}")
127188
if len(title_parts) > 1:
@@ -138,7 +199,7 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
138199
ax.set_ylim(ymin - margin, ymax + margin)
139200
else:
140201
y_max = min(ymin + 30, ymax)
141-
ax.set_ylim(ymin - 1, y_max)
202+
ax.set_ylim(ymin - 0.25, y_max)
142203
fig.tight_layout()
143204

144205
fig_path = os.path.join(out_dir, f"{code}.png")
@@ -173,8 +234,8 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
173234
for code in db.get_all_codes():
174235
#if code not in ['ACSALA']: continue
175236
#if code not in ['FUNZOE']: continue
176-
if code in ['ACEMID02']: continue
177-
#if code not in ['XAFPAY', 'OBEQIX', 'UJIRIO02']: continue
237+
if code not in ['XAFQON']: continue
238+
#if code not in ['ACEMID02']: continue
178239
row = db.get_row(code=code)
179240
ref_xtal = db.get_pyxtal(code=code)
180241
if ref_xtal.has_special_site():
@@ -278,24 +339,33 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
278339
param_xml = os.path.join(workdir, "parameters.xml")
279340
if os.path.exists(param_xml):
280341
os.remove(param_xml)
342+
print(f"Initialized QRS workdir: {workdir}", ref_xtal.lattice)
343+
composition = [int(a) for a in ref_xtal.get_zprime()]
344+
# determine per-component delta angles and show them
345+
selected_deltas = select_delta_angle(molecules, composition)
346+
print(f"Selected delta_angle(s) for components: {selected_deltas}")
347+
281348
qrs = QRS(
282349
smiles=row.mol_smi,
283350
workdir=workdir,
284351
sg=ref_xtal.group.hall_number,
285352
tag=row.csd_code.lower(),
286353
use_hall=True,
287354
lattice=ref_xtal.lattice, # Fixed cell.
288-
composition = [int(a) for a in ref_xtal.get_zprime()],
355+
composition = composition,
289356
molecules=molecules,
290357
sites=sites,
291-
N_gen=200,
292-
N_pop=96,
293-
N_cpu=48,
358+
N_gen=1, #00,
359+
N_pop=4, #8,
360+
N_cpu=1, #2,#4,
294361
cif="all.cif",
295362
skip_mlp=True,
363+
mlp='MACEOFF',
296364
verbose=False,
297365
delta_length=1.0,
298-
delta_angle=15.0,
366+
# Pass per-component delta_angle (scalar or list). QRS will
367+
# expand per-component values to per-site internally.
368+
delta_angle=selected_deltas,
299369
)
300370

301371
t0 = perf_counter()
@@ -322,6 +392,7 @@ def plot_id_vs_energy(code, energies, match_ids=None, match_energies=None, out_d
322392
time_cost_s=time_cost_s,
323393
coverage=coverage,
324394
n_conformers=n_pregen_total,
395+
energy_unit="eV/atom" if not qrs.skip_mlp else "kcal/mol",
325396
)
326397

327398
with open(csv_path, "a", newline="") as fcsv:

pyxtal/optimize/QRS.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,16 +277,33 @@ def compute_wp_resolutions(wp_bounds, cell_lengths, delta_length=1.0, delta_angl
277277
list[int]: number of levels per DOF (flat, same ordering as wp_bounds)
278278
"""
279279
n_levels = []
280-
for site_bounds in wp_bounds:
280+
281+
# delta_angle may be a scalar (float) or an iterable providing a per-site
282+
# angle resolution. If iterable and length matches wp_bounds, use the
283+
# corresponding per-site delta; otherwise broadcast the scalar value.
284+
if isinstance(delta_angle, (list, tuple, np.ndarray)):
285+
if len(delta_angle) == len(wp_bounds):
286+
per_site_delta = list(delta_angle)
287+
else:
288+
# Unexpected shape: fall back to first element or treat as scalar
289+
try:
290+
per_site_delta = [float(delta_angle[0])] * len(wp_bounds)
291+
except Exception:
292+
per_site_delta = [float(delta_angle)] * len(wp_bounds)
293+
else:
294+
per_site_delta = [float(delta_angle)] * len(wp_bounds)
295+
296+
for site_idx, site_bounds in enumerate(wp_bounds):
281297
coord_idx = 0
298+
site_delta = per_site_delta[site_idx]
282299
for (lb, ub) in site_bounds:
283300
span = ub - lb
284301
if abs(span - 1.0) < 1e-6: # fractional coordinate DOF
285302
edge = cell_lengths[coord_idx] if coord_idx < len(cell_lengths) else 1.0
286303
n = max(1, int(edge / delta_length))
287304
coord_idx += 1
288305
else: # angle DOF (Euler or torsion)
289-
n = max(1, int(round(span / delta_angle)))
306+
n = max(1, int(round(span / site_delta)))
290307
n_levels.append(n)
291308
return n_levels
292309

@@ -527,13 +544,23 @@ def _init_qrs_params(self):
527544
self.ltype = self.lattice.ltype
528545
self.wp_bounds = [site.get_bounds() for site in tmp.mol_sites]
529546
grid_wp_bounds = trim_wp_bounds_for_molecules(self.wp_bounds, self.composition, self.molecules)
530-
531-
if self.delta_length > 0 or self.delta_angle > 0:
547+
if self.delta_length > 0 or (isinstance(self.delta_angle, (int, float)) and self.delta_angle > 0) or (isinstance(self.delta_angle, (list, tuple)) and any([d > 0 for d in self.delta_angle])):
532548
# Uneven grid: per-dim resolution derived from cell lengths / angle range
533549
dl = self.delta_length if self.delta_length > 0 else 1.0
534-
da = self.delta_angle if self.delta_angle > 0 else 30.0
550+
da = self.delta_angle if self.delta_angle is not None else 30.0
535551
a, b, c = self.lattice.get_para()[:3]
536-
n_levels = compute_wp_resolutions(grid_wp_bounds, [a, b, c], dl, da)
552+
553+
# If da is a per-component list (one value per composition entry),
554+
# expand it into a per-site list matching grid_wp_bounds.
555+
if isinstance(da, (list, tuple)) and len(da) == len(self.composition):
556+
per_site_da = []
557+
for comp_idx, cnt in enumerate(self.composition):
558+
per_site_da.extend([da[comp_idx]] * int(cnt))
559+
else:
560+
per_site_da = da
561+
562+
n_levels = compute_wp_resolutions(grid_wp_bounds, [a, b, c], dl, per_site_da)
563+
print(f"Computed per-DOF grid levels: {n_levels}")
537564
self.sampler = GridSampler(n_levels)
538565
print(f"GridSampler initialised: {self.sampler}")
539566
else:

0 commit comments

Comments
 (0)