Skip to content

Commit 9e21f1f

Browse files
authored
Merge pull request #425 from anyangml2nd/feat/support-rxn39
Feat: support rxn39
2 parents ddf2bd2 + 6cbf4b5 commit 9e21f1f

5 files changed

Lines changed: 123 additions & 0 deletions

File tree

lambench/metrics/downstream_tasks_metrics.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ rxn_barrier:
3232
domain: Molecules
3333
metrics: [MAE]
3434
dummy: {"MAE": 20.975}
35+
rxn_path39:
36+
domain: Molecules
37+
metrics: [MAE] # RMSE is not used for calculating metrics
38+
dummy: {"MAE": 34.109} # "RMSE": 43.150
3539
pressure:
3640
domain: Inorganic Materials
3741
metrics: [MAE]

lambench/metrics/post_process.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def process_domain_specific_for_one_model(model: BaseLargeAtomModel):
120120
"vacancy",
121121
"binding_energy",
122122
"rxn_barrier",
123+
"rxn_path39",
123124
"pressure",
124125
"stacking_fault",
125126
"interface",

lambench/models/ase_models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,18 @@ def evaluate(
280280
elif task.task_name == "wiggle150":
281281
from lambench.tasks.calculator.wiggle150.wiggle150 import run_inference
282282

283+
assert task.test_data is not None
284+
return {
285+
"metrics": run_inference(
286+
self,
287+
task.test_data,
288+
)
289+
}
290+
elif task.task_name == "rxn_path39":
291+
from lambench.tasks.calculator.rxn_path39.rxn_path39 import (
292+
run_inference,
293+
)
294+
283295
assert task.test_data is not None
284296
return {
285297
"metrics": run_inference(

lambench/tasks/calculator/calculator_tasks.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ neb:
2323
wiggle150:
2424
test_data: /bohr/lambench-wiggle150-yazy/v1/Wiggle150.traj
2525
calculator_params: null
26+
rxn_path39:
27+
test_data: /bohr/lambench-rxn39-755z/v2/trajs
28+
calculator_params: null
2629
elastic:
2730
test_data: /bohr/lambench-elastic-9qdt/v1/elastic.json
2831
calculator_params:
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
RXN-Path-39: 13 organic reactions (wB97M-V/def2-TZVPD), 3 path-sampling
3+
trajectories each, 11 arc-length-equidistant frames per trajectory.
4+
5+
For each trajectory the first frame (index 0) is chosen as the reference.
6+
The task measures how accurately a LAM reproduces the relative energies of all
7+
other frames with respect to that reference, i.e.
8+
9+
ΔE_DFT(i) = E_DFT(i) − E_DFT(frame 0) [kcal/mol]
10+
ΔE_LAM(i) = E_LAM(i) − E_LAM(frame 0) [kcal/mol]
11+
12+
and reports MAE and RMSE over all 39 × 10 = 390 (reaction, frame) pairs.
13+
"""
14+
15+
from pathlib import Path
16+
import logging
17+
18+
import numpy as np
19+
from ase.io import Trajectory
20+
from sklearn.metrics import mean_absolute_error, root_mean_squared_error
21+
22+
from lambench.models.ase_models import ASEModel
23+
24+
EV_TO_KCAL = 23.0609 # 1 eV = 23.0609 kcal/mol
25+
26+
27+
def run_inference(model: ASEModel, test_data: Path) -> dict[str, float]:
28+
"""
29+
Parameters
30+
----------
31+
model : ASEModel
32+
test_data : Path
33+
Root of the trajectory tree. Expected layout::
34+
35+
test_data/
36+
<reaction_id>/
37+
traj_0.traj
38+
traj_1.traj
39+
traj_2.traj
40+
...
41+
42+
Returns
43+
-------
44+
dict with keys "MAE" and "RMSE" in kcal/mol.
45+
"""
46+
calc = model.calc
47+
label_diffs: list[float] = []
48+
pred_diffs: list[float] = []
49+
50+
traj_files = sorted(test_data.rglob("traj_*.traj"))
51+
if not traj_files:
52+
raise FileNotFoundError(f"No traj_*.traj files found under {test_data}")
53+
54+
for traj_path in traj_files:
55+
frames = list(Trajectory(traj_path))
56+
57+
# DFT reference energies (eV, stored by SinglePointCalculator)
58+
dft_energies = np.array([a.get_potential_energy() for a in frames])
59+
ref_dft_kcal = dft_energies[0] * EV_TO_KCAL
60+
61+
# LAM energy for the first frame (reference)
62+
frames[0].calc = calc
63+
try:
64+
ref_pred_kcal = frames[0].get_potential_energy() * EV_TO_KCAL
65+
except Exception as e:
66+
logging.error(
67+
f"Failed predicting reference frame (idx=0) in {traj_path}: {e}"
68+
)
69+
continue # skip this trajectory entirely
70+
71+
# Relative energies for every non-reference frame
72+
for i, atoms in enumerate(frames):
73+
if i == 0:
74+
continue
75+
76+
label_diffs.append(dft_energies[i] * EV_TO_KCAL - ref_dft_kcal)
77+
78+
atoms.calc = calc
79+
try:
80+
pred_kcal = atoms.get_potential_energy() * EV_TO_KCAL
81+
except Exception as e:
82+
logging.error(f"Failed predicting frame {i} of {traj_path}: {e}")
83+
pred_kcal = np.nan
84+
pred_diffs.append(pred_kcal - ref_pred_kcal)
85+
86+
label_arr = np.array(label_diffs)
87+
pred_arr = np.array(pred_diffs)
88+
valid = np.isfinite(pred_arr)
89+
90+
if not valid.any():
91+
logging.error("All predictions failed; returning NaN metrics.")
92+
return {"MAE": np.nan, "RMSE": np.nan}
93+
94+
if not valid.all():
95+
n_failed = int((~valid).sum())
96+
logging.warning(
97+
f"{n_failed} frame(s) failed inference and were excluded from metrics."
98+
)
99+
100+
return {
101+
"MAE": float(mean_absolute_error(label_arr[valid], pred_arr[valid])),
102+
"RMSE": float(root_mean_squared_error(label_arr[valid], pred_arr[valid])),
103+
}

0 commit comments

Comments
 (0)