Skip to content

Commit 4bf2b1b

Browse files
committed
allow all
1 parent 98106c4 commit 4bf2b1b

1 file changed

Lines changed: 66 additions & 5 deletions

File tree

examples/example_10_mlp_relax.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,56 @@ def worker_relax(args):
214214
return (label, label_idx, orig_energy, None, f'relax_exception: {e}', None, time.time() - start_time)
215215

216216

217+
def find_qrs_job_dirs(root):
218+
"""Return (input_dir, output_dir) pairs under root that contain QRS-openffall.cif."""
219+
root = os.path.abspath(root)
220+
direct_cif = os.path.join(root, 'QRS-openffall.cif')
221+
if os.path.isfile(direct_cif):
222+
return [(root, root)]
223+
224+
jobs = []
225+
for name in sorted(os.listdir(root)):
226+
sub = os.path.join(root, name)
227+
if os.path.isdir(sub) and os.path.isfile(os.path.join(sub, 'QRS-openffall.cif')):
228+
jobs.append((sub, sub))
229+
return jobs
230+
231+
232+
def collect_jobs(cif_path, out_dir=None):
233+
"""Resolve one or many QRS jobs from a CIF path and optional output directory."""
234+
cif_path = os.path.abspath(cif_path)
235+
236+
if os.path.isfile(cif_path):
237+
parent = os.path.dirname(cif_path)
238+
job_out = os.path.abspath(out_dir) if out_dir else parent
239+
return [(parent, job_out)]
240+
241+
if out_dir is not None:
242+
out_dir = os.path.abspath(out_dir)
243+
if os.path.isdir(out_dir) and not os.path.isfile(os.path.join(out_dir, 'QRS-openffall.cif')):
244+
jobs = []
245+
for sub, _ in find_qrs_job_dirs(out_dir):
246+
jobs.append((sub, sub))
247+
if jobs:
248+
return jobs
249+
250+
if os.path.isdir(cif_path):
251+
if not os.path.isfile(os.path.join(cif_path, 'QRS-openffall.cif')):
252+
jobs = find_qrs_job_dirs(cif_path)
253+
if jobs:
254+
if out_dir is not None:
255+
out_root = os.path.abspath(out_dir)
256+
return [
257+
(sub, os.path.join(out_root, os.path.basename(sub)))
258+
for sub, _ in jobs
259+
]
260+
return jobs
261+
job_out = out_dir if out_dir is not None else cif_path
262+
return [(cif_path, os.path.abspath(job_out))]
263+
264+
raise FileNotFoundError(f"No QRS-openffall.cif found in {cif_path}")
265+
266+
217267
def main(cif_path, nproc=4, step=200, fmax=0.1, out_dir=None, db_file=None, ref_code=None, matched_cif=None, cutoff_pct=50.0, e_tol=1e-3, calculator='MACE', model=None, quick=False):
218268
main_start = time.time()
219269
input_folder = None
@@ -516,7 +566,7 @@ def main(cif_path, nproc=4, step=200, fmax=0.1, out_dir=None, db_file=None, ref_
516566
pass
517567

518568
parser = argparse.ArgumentParser(description='Relax QRS CIF blocks with MACE and compare energies')
519-
parser.add_argument('cif', help='Path to QRS multi-block CIF or directory containing QRS-openffall.cif')
569+
parser.add_argument('cif', help='Path to QRS multi-block CIF, a case folder, or a parent folder of case folders')
520570
parser.add_argument('--nproc', type=int, default=4, help='Number of parallel workers')
521571
parser.add_argument('--pct', type=float, default=10.0, help='Energy percentile cutoff (e.g. 10 for 10%%)')
522572
parser.add_argument('--db', default="pyxtal/database/test.db", help='Path to test.db for reference matching')
@@ -535,17 +585,15 @@ def main(cif_path, nproc=4, step=200, fmax=0.1, out_dir=None, db_file=None, ref_
535585
action='store_true',
536586
help='Use a cheaper/faster model preset for quick testing (ORB: direct-20-omat, MACEOFF: small)',
537587
)
538-
parser.add_argument('--out-dir', default=None, help='Output directory (defaults to input folder)')
588+
parser.add_argument('--out-dir', default=None, help='Output directory (defaults to input folder; if set without QRS-openffall.cif, process all subdirs)')
539589
args = parser.parse_args()
540590
if args.quick and args.model is not None:
541591
parser.error('Use only one of --quick or --model')
542592

543-
main(
544-
args.cif,
593+
common_kwargs = dict(
545594
nproc=args.nproc,
546595
step=args.step,
547596
fmax=args.fmax,
548-
out_dir=args.out_dir,
549597
db_file=args.db,
550598
ref_code=None,
551599
matched_cif=args.matched_cif,
@@ -555,3 +603,16 @@ def main(cif_path, nproc=4, step=200, fmax=0.1, out_dir=None, db_file=None, ref_
555603
model=args.model,
556604
quick=args.quick,
557605
)
606+
607+
jobs = collect_jobs(args.cif, args.out_dir)
608+
if not jobs:
609+
parser.error(f'No QRS-openffall.cif found under {args.cif}')
610+
611+
for i, (job_dir, job_out_dir) in enumerate(jobs, start=1):
612+
if len(jobs) > 1:
613+
print(f'\n=== [{i}/{len(jobs)}] Processing {job_dir} ===')
614+
main(
615+
job_dir,
616+
out_dir=job_out_dir,
617+
**common_kwargs,
618+
)

0 commit comments

Comments
 (0)