Skip to content

Commit e157ed7

Browse files
committed
feat: Implement XAS energy normalization in the XAS loss function and introduce a dedicated XAS model.
1 parent c8a4005 commit e157ed7

7 files changed

Lines changed: 315 additions & 35 deletions

File tree

deepmd/entrypoints/test.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -887,17 +887,20 @@ def test_property(
887887
high_prec=True,
888888
)
889889

890+
is_xas = var_name == "xas"
891+
890892
if dp.get_dim_fparam() > 0:
891893
data.add(
892894
"fparam", dp.get_dim_fparam(), atomic=False, must=True, high_prec=False
893895
)
894896
if dp.get_dim_aparam() > 0:
895897
data.add("aparam", dp.get_dim_aparam(), atomic=True, must=True, high_prec=False)
896898

897-
# sel_type: optional per-frame type index for element-wise mean reduction (XAS)
898-
data.add(
899-
"sel_type", 1, atomic=False, must=False, high_prec=False, default=float(-1)
900-
)
899+
# XAS requires sel_type.npy (per-frame absorbing element type index)
900+
if is_xas:
901+
data.add(
902+
"sel_type", 1, atomic=False, must=True, high_prec=False
903+
)
901904

902905
test_data = data.get_test()
903906
mixed_type = data.mixed_type
@@ -923,12 +926,8 @@ def test_property(
923926
else:
924927
aparam = None
925928

926-
# detect whether this system provides sel_type (XAS-style reduction)
927-
sel_type_raw = test_data["sel_type"][:numb_test, 0] # [numb_test]
928-
has_sel_type = bool((sel_type_raw >= 0).all())
929-
930-
# for sel_type reduction we need per-atom outputs
931-
eval_atomic = has_atom_property or has_sel_type
929+
# XAS: per-atom outputs are needed to average over absorbing-element atoms
930+
eval_atomic = has_atom_property or is_xas
932931
ret = dp.eval(
933932
coord,
934933
box,
@@ -939,27 +938,44 @@ def test_property(
939938
mixed_type=mixed_type,
940939
)
941940

942-
if has_sel_type:
941+
if is_xas:
943942
# ret[1]: per-atom property [numb_test, natoms, task_dim]
944943
atom_prop = ret[1].reshape([numb_test, natoms, dp.task_dim])
945-
# atype for all frames
946944
if mixed_type:
947945
atype_frames = atype # [numb_test, natoms]
948946
else:
949947
atype_frames = np.tile(atype, (numb_test, 1)) # [numb_test, natoms]
950-
sel_type_int = sel_type_raw.astype(int)
948+
sel_type_int = test_data["sel_type"][:numb_test, 0].astype(int)
951949
property = np.zeros([numb_test, dp.task_dim], dtype=atom_prop.dtype)
952950
for i in range(numb_test):
953951
t = sel_type_int[i]
954952
mask = atype_frames[i] == t # [natoms]
955953
count = max(mask.sum(), 1)
956954
property[i] = atom_prop[i][mask].sum(axis=0) / count
955+
956+
# Add back the per-(type, edge) energy reference so output is in
957+
# absolute eV (matching label format). xas_e_ref is saved in the
958+
# model checkpoint by XASLoss.compute_output_stats.
959+
try:
960+
xas_e_ref = dp.dp.model["Default"].atomic_model.xas_e_ref
961+
except AttributeError:
962+
xas_e_ref = None
963+
if xas_e_ref is not None and fparam is not None:
964+
import torch as _torch
965+
edge_idx_all = _torch.tensor(
966+
fparam.reshape(numb_test, -1)
967+
).argmax(dim=-1).numpy()
968+
e_ref_np = xas_e_ref.cpu().numpy() # [ntypes, nfparam, 2]
969+
for i in range(numb_test):
970+
t = sel_type_int[i]
971+
e = int(edge_idx_all[i])
972+
property[i, :2] += e_ref_np[t, e]
957973
else:
958974
property = ret[0]
959975

960976
property = property.reshape([numb_test, dp.task_dim])
961977

962-
if has_atom_property:
978+
if has_atom_property and not is_xas:
963979
aproperty = ret[1]
964980
aproperty = aproperty.reshape([numb_test, natoms * dp.task_dim])
965981

deepmd/pt/loss/xas.py

Lines changed: 200 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,15 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import logging
3-
from typing import (
4-
Any,
5-
)
3+
from collections import defaultdict
4+
from typing import Any
65

6+
import numpy as np
77
import torch
88
import torch.nn.functional as F
99

10-
from deepmd.pt.loss.loss import (
11-
TaskLoss,
12-
)
13-
from deepmd.pt.utils import (
14-
env,
15-
)
16-
from deepmd.utils.data import (
17-
DataRequirementItem,
18-
)
10+
from deepmd.pt.loss.loss import TaskLoss
11+
from deepmd.pt.utils import env
12+
from deepmd.utils.data import DataRequirementItem
1913

2014
log = logging.getLogger(__name__)
2115

@@ -28,10 +22,31 @@ class XASLoss(TaskLoss):
2822
in each training system) and takes their mean, then computes a loss against
2923
the per-frame XAS label.
3024
25+
Energy normalization
26+
--------------------
27+
XAS labels contain absolute edge energies (E_min, E_max in eV) that vary
28+
enormously across element-edge pairs (H_K ~14 eV, Th_K ~110000 eV).
29+
Training directly on absolute values causes gradient instability because
30+
the energy dimensions dwarf the intensity dimensions.
31+
32+
``compute_output_stats`` computes a reference energy ``e_ref[t, e]`` for
33+
every ``(absorbing_type t, edge_index e)`` combination from the training
34+
data and stores it as a registered buffer. During training, ``forward``
35+
normalises labels and predictions by subtracting the per-frame reference
36+
so that the loss is computed on chemical shifts (±few eV) and normalised
37+
intensities—quantities of comparable magnitude.
38+
39+
The buffer is saved in the model checkpoint, eliminating any need for
40+
external normalisation files.
41+
3142
Parameters
3243
----------
3344
task_dim : int
3445
Output dimension of the fitting net (e.g. 102 = E_min + E_max + 100 pts).
46+
ntypes : int
47+
Number of atom types in the model.
48+
nfparam : int
49+
Length of the fparam one-hot vector (= number of edge types).
3550
var_name : str
3651
Property name, must match ``property_name`` in the fitting config.
3752
loss_func : str
@@ -45,6 +60,8 @@ class XASLoss(TaskLoss):
4560
def __init__(
4661
self,
4762
task_dim: int,
63+
ntypes: int,
64+
nfparam: int,
4865
var_name: str = "xas",
4966
loss_func: str = "smooth_mae",
5067
metric: list[str] = ["mae"],
@@ -53,11 +70,141 @@ def __init__(
5370
) -> None:
5471
super().__init__()
5572
self.task_dim = task_dim
73+
self.ntypes = ntypes
74+
self.nfparam = nfparam
5675
self.var_name = var_name
5776
self.loss_func = loss_func
5877
self.metric = metric
5978
self.beta = beta
6079

80+
# e_ref[sel_type_idx, edge_idx, 0] = mean E_min (eV)
81+
# e_ref[sel_type_idx, edge_idx, 1] = mean E_max (eV)
82+
# Shape: [ntypes, nfparam, 2]. Filled by compute_output_stats; zero until then.
83+
self.register_buffer(
84+
"e_ref",
85+
torch.zeros(ntypes, nfparam, 2, dtype=env.GLOBAL_PT_FLOAT_PRECISION),
86+
)
87+
88+
# ------------------------------------------------------------------
89+
# Stat phase: compute per-(absorbing_type, edge) reference energies
90+
# ------------------------------------------------------------------
91+
def compute_output_stats(
    self,
    sampled: list[dict],
    model: "torch.nn.Module | None" = None,
) -> None:
    """Compute ``e_ref`` and reset the model's energy-dim bias/std.

    For every ``(absorbing type, edge index)`` pair found in ``sampled`` the
    mean of the first two label columns (E_min, E_max) is stored in the
    ``e_ref`` buffer. The edge index of a frame is the argmax of its
    ``fparam`` row. Frames whose ``sel_type`` lies outside
    ``[0, ntypes)`` are ignored. Pairs that never occur keep a zero
    reference.

    Parameters
    ----------
    sampled : list[dict]
        Stat batches. A batch is skipped unless it provides the label
        (``self.var_name``), ``sel_type`` and ``fparam`` as tensors.
    model : torch.nn.Module, optional
        When given, ``model.atomic_model`` is updated on a best-effort
        basis: ``e_ref`` is copied into its ``xas_e_ref`` buffer (if it has
        one), and ``out_bias`` / ``out_std`` of the two energy dimensions
        are reset to 0 / 1 for this property. Any failure is logged as a
        warning and does not abort the stat phase.

    Notes
    -----
    Rationale kept from the original implementation: with the
    stat-initialised ``out_bias`` / ``out_std`` the energy outputs are
    scaled by the global spread of absolute edge energies (the author
    quotes ~26 000 eV), so a small change of the raw network output moves
    the physical output by several eV. With bias 0 / std 1 the network
    predicts the shift relative to ``e_ref`` directly.
    """
    ntypes, nfparam = self.ntypes, self.nfparam
    n_slots = ntypes * nfparam
    required = (self.var_name, "sel_type", "fparam")

    # Running sums are kept on CPU in float64: absolute edge energies can be
    # large, and the statistics must not depend on the working precision.
    energy_sum = torch.zeros(n_slots, 2, dtype=torch.float64)
    counts = torch.zeros(n_slots, dtype=torch.long)

    for frame in sampled:
        if any(frame.get(key) is None for key in required):
            continue

        # Flatten to [nf, task_dim], [nf] and [nf]; one device transfer per
        # batch instead of one host sync per frame.
        xas = frame[self.var_name].detach().reshape(-1, self.task_dim).cpu()
        sel_type = frame["sel_type"].detach().reshape(-1).long().cpu()
        edge_idx = (
            frame["fparam"].detach().reshape(-1, nfparam).argmax(dim=-1).cpu()
        )

        # argmax is always inside [0, nfparam), so only sel_type needs a
        # range check (e.g. to drop negative placeholder values).
        valid = (sel_type >= 0) & (sel_type < ntypes)
        slot = sel_type[valid] * nfparam + edge_idx[valid]
        energy_sum.index_add_(0, slot, xas[valid, :2].to(torch.float64))
        counts += torch.bincount(slot, minlength=n_slots)

    seen = torch.nonzero(counts).flatten().tolist()
    if not seen:
        log.warning(
            "XASLoss.compute_output_stats: no frames with xas+sel_type+fparam found; "
            "e_ref remains zero. Training may be unstable."
        )
        return

    # clamp avoids 0/0 for unseen pairs; their sum is zero, so e_ref stays 0.
    mean = (energy_sum / counts.clamp(min=1).unsqueeze(-1)).reshape(
        ntypes, nfparam, 2
    )
    with torch.no_grad():
        # copy_ converts to the buffer's own dtype and device.
        self.e_ref.copy_(mean)

    mean_rows = mean.tolist()
    for s in seen:
        t, e = divmod(s, nfparam)
        log.info(
            "XASLoss e_ref: type=%d, edge=%d -> "
            "E_min_ref=%.2f eV, E_max_ref=%.2f eV (n=%d)",
            t,
            e,
            mean_rows[t][e][0],
            mean_rows[t][e][1],
            int(counts[s]),
        )
    log.info(
        "XASLoss: e_ref computed for %d (sel_type, edge) combinations.",
        len(seen),
    )

    if model is None:
        return

    # Deliberately best-effort: a model without the expected attributes must
    # not break the stat phase, so every failure is downgraded to a warning.
    try:
        am = model.atomic_model

        # 1. Mirror e_ref into the model's own buffer so the reference is
        #    available from the model itself. Copying from self.e_ref keeps
        #    the two buffers numerically identical.
        if getattr(am, "xas_e_ref", None) is not None:
            with torch.no_grad():
                am.xas_e_ref.copy_(self.e_ref)
            log.info("XASLoss: copied e_ref → model.atomic_model.xas_e_ref.")

        # 2. Reset the energy-dim output statistics so the network output
        #    for columns 0 and 1 is used as-is (shift relative to e_ref)
        #    rather than being rescaled by the stat-initialised bias/std.
        key_idx = am.bias_keys.index(self.var_name)
        with torch.no_grad():
            am.out_bias[key_idx, :, :2] = 0.0
            am.out_std[key_idx, :, :2] = 1.0
        log.info(
            "XASLoss: reset out_bias[:,:2]=0 and out_std[:,:2]=1 "
            "for energy dims (model predicts chemical shifts; "
            "xas_e_ref restores absolute energies at inference)."
        )
    except Exception as exc:
        log.warning("XASLoss: could not update model energy-dim stats: %s", exc)
204+
205+
# ------------------------------------------------------------------
206+
# Forward
207+
# ------------------------------------------------------------------
61208
def forward(
62209
self,
63210
input_dict: dict[str, torch.Tensor],
@@ -76,7 +223,7 @@ def forward(
76223
# sel_type from label: [nf, 1] float → [nf] int
77224
sel_type = label["sel_type"][:, 0].long()
78225

79-
# element-wise mean: for each frame average over atoms of sel_type
226+
# element-wise mean: average atom_prop over atoms of sel_type per frame
80227
nf, nloc, td = atom_prop.shape
81228
pred = torch.zeros(nf, td, dtype=atom_prop.dtype, device=atom_prop.device)
82229
for i in range(nf):
@@ -87,27 +234,60 @@ def forward(
87234

88235
label_xas = label[self.var_name] # [nf, task_dim]
89236

237+
# --- per-frame reference energy lookup ---
238+
# edge_idx = argmax of one-hot fparam
239+
fparam = input_dict.get("fparam")
240+
if fparam is not None and fparam.numel() > 0:
241+
edge_idx = fparam.reshape(nf, -1).argmax(dim=-1).clamp(0, self.nfparam - 1)
242+
else:
243+
edge_idx = torch.zeros(nf, dtype=torch.long, device=pred.device)
244+
245+
# e_ref_frame: [nf, 2] (E_min_ref, E_max_ref for each frame)
246+
e_ref_frame = self.e_ref[sel_type, edge_idx] # [nf, 2]
247+
248+
# Shift the energy-dim TARGETS only.
249+
#
250+
# After compute_output_stats has reset out_bias[:,:2]=0 / out_std[:,:2]=1,
251+
# the model outputs raw NN values ≈ 0 for dims 0,1. We train those
252+
# dims against (label − e_ref), i.e. the chemical shift (±few eV),
253+
# keeping gradient magnitudes O(1). Intensity dims (2:) are trained
254+
# against the original label values unchanged.
255+
#
256+
# At inference, we add e_ref back to get the absolute edge energy.
257+
label_shifted = label_xas.clone()
258+
label_shifted[:, :2] = label_xas[:, :2] - e_ref_frame
259+
260+
# --- loss ---
90261
loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
91262
if self.loss_func == "smooth_mae":
92-
loss += F.smooth_l1_loss(pred, label_xas, reduction="sum", beta=self.beta)
263+
loss += F.smooth_l1_loss(
264+
pred, label_shifted, reduction="sum", beta=self.beta
265+
)
93266
elif self.loss_func == "mae":
94-
loss += F.l1_loss(pred, label_xas, reduction="sum")
267+
loss += F.l1_loss(pred, label_shifted, reduction="sum")
95268
elif self.loss_func == "mse":
96-
loss += F.mse_loss(pred, label_xas, reduction="sum")
269+
loss += F.mse_loss(pred, label_shifted, reduction="sum")
97270
elif self.loss_func == "rmse":
98-
loss += torch.sqrt(F.mse_loss(pred, label_xas, reduction="mean"))
271+
loss += torch.sqrt(F.mse_loss(pred, label_shifted, reduction="mean"))
99272
else:
100273
raise RuntimeError(f"Unknown loss function: {self.loss_func}")
101274

275+
# --- metrics ---
102276
more_loss: dict[str, torch.Tensor] = {}
103277
if "mae" in self.metric:
104-
more_loss["mae"] = F.l1_loss(pred, label_xas, reduction="mean").detach()
278+
more_loss["mae"] = F.l1_loss(
279+
pred, label_shifted, reduction="mean"
280+
).detach()
105281
if "rmse" in self.metric:
106282
more_loss["rmse"] = torch.sqrt(
107-
F.mse_loss(pred, label_xas, reduction="mean")
283+
F.mse_loss(pred, label_shifted, reduction="mean")
108284
).detach()
109285

110-
model_pred[self.var_name] = pred
286+
# Absolute prediction: add e_ref back to energy dims for eval / output
287+
pred_abs = pred.clone()
288+
pred_abs[:, :2] = pred[:, :2] + e_ref_frame
289+
model_pred[self.var_name] = pred_abs
290+
111291
return model_pred, loss, more_loss
112292

113293
@property

deepmd/pt/model/atomic_model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from .property_atomic_model import (
4343
DPPropertyAtomicModel,
44+
DPXASAtomicModel,
4445
)
4546

4647
__all__ = [
@@ -51,6 +52,7 @@
5152
"DPEnergyAtomicModel",
5253
"DPPolarAtomicModel",
5354
"DPPropertyAtomicModel",
55+
"DPXASAtomicModel",
5456
"DPZBLLinearEnergyAtomicModel",
5557
"LinearEnergyAtomicModel",
5658
"PairTabAtomicModel",

0 commit comments

Comments
 (0)