11# SPDX-License-Identifier: LGPL-3.0-or-later
22import logging
3- from typing import (
4- Any ,
5- )
3+ from collections import defaultdict
4+ from typing import Any
65
6+ import numpy as np
77import torch
88import torch .nn .functional as F
99
10- from deepmd .pt .loss .loss import (
11- TaskLoss ,
12- )
13- from deepmd .pt .utils import (
14- env ,
15- )
16- from deepmd .utils .data import (
17- DataRequirementItem ,
18- )
10+ from deepmd .pt .loss .loss import TaskLoss
11+ from deepmd .pt .utils import env
12+ from deepmd .utils .data import DataRequirementItem
1913
2014log = logging .getLogger (__name__ )
2115
@@ -28,10 +22,31 @@ class XASLoss(TaskLoss):
2822 in each training system) and takes their mean, then computes a loss against
2923 the per-frame XAS label.
3024
25+ Energy normalization
26+ --------------------
27+ XAS labels contain absolute edge energies (E_min, E_max in eV) that vary
28+ enormously across element-edge pairs (H_K ~14 eV, Th_K ~110000 eV).
29+ Training directly on absolute values causes gradient instability because
30+ the energy dimensions dwarf the intensity dimensions.
31+
32+ ``compute_output_stats`` computes a reference energy ``e_ref[t, e]`` for
33+ every ``(absorbing_type t, edge_index e)`` combination from the training
34+ data and stores it as a registered buffer. During training, ``forward``
35+ subtracts the per-frame reference from the energy dimensions of the label
36+ only, so that the loss is computed on chemical shifts (±few eV) and normalised
37+ intensities—quantities of comparable magnitude.
38+
39+ The buffer is saved in the model checkpoint, eliminating any need for
40+ external normalisation files.
41+
3142 Parameters
3243 ----------
3344 task_dim : int
3445 Output dimension of the fitting net (e.g. 102 = E_min + E_max + 100 pts).
46+ ntypes : int
47+ Number of atom types in the model.
48+ nfparam : int
49+ Length of the fparam one-hot vector (= number of edge types).
3550 var_name : str
3651 Property name, must match ``property_name`` in the fitting config.
3752 loss_func : str
@@ -45,6 +60,8 @@ class XASLoss(TaskLoss):
4560 def __init__ (
4661 self ,
4762 task_dim : int ,
63+ ntypes : int ,
64+ nfparam : int ,
4865 var_name : str = "xas" ,
4966 loss_func : str = "smooth_mae" ,
5067 metric : list [str ] = ["mae" ],
@@ -53,11 +70,141 @@ def __init__(
5370 ) -> None :
5471 super ().__init__ ()
5572 self .task_dim = task_dim
73+ self .ntypes = ntypes
74+ self .nfparam = nfparam
5675 self .var_name = var_name
5776 self .loss_func = loss_func
5877 self .metric = metric
5978 self .beta = beta
6079
80+ # e_ref[sel_type_idx, edge_idx, 0] = mean E_min (eV)
81+ # e_ref[sel_type_idx, edge_idx, 1] = mean E_max (eV)
82+ # Shape: [ntypes, nfparam, 2]. Filled by compute_output_stats; zero until then.
83+ self .register_buffer (
84+ "e_ref" ,
85+ torch .zeros (ntypes , nfparam , 2 , dtype = env .GLOBAL_PT_FLOAT_PRECISION ),
86+ )
87+
88+ # ------------------------------------------------------------------
89+ # Stat phase: compute per-(absorbing_type, edge) reference energies
90+ # ------------------------------------------------------------------
91+ def compute_output_stats (
92+ self ,
93+ sampled : list [dict ],
94+ model : "torch.nn.Module | None" = None ,
95+ ) -> None :
96+ """Compute ``e_ref`` and fix model energy-dim bias/std.
97+
98+ Called once before training starts. Requires ``xas``, ``sel_type``,
99+ and ``fparam`` in at least some samples.
100+
101+ Parameters
102+ ----------
103+ sampled : list[dict]
104+ List of data batches from ``make_stat_input``.
105+ model : nn.Module, optional
106+ The full DeePMD model. When given, the per-atom property model's
107+ ``out_bias`` and ``out_std`` for the two energy dimensions (E_min,
108+ E_max) are reset to 0 / 1 so the NN predicts *chemical shifts*
109+ (±few eV) instead of absolute energies (~thousands of eV).
110+ Without this reset the stat-initialised ``out_std ≈ 26 000 eV``
111+ amplifies weight-update steps by 26 000×, causing immediate
112+ gradient explosion.
113+ """
114+ accum : dict [tuple [int , int ], list ] = defaultdict (list )
115+
116+ for frame in sampled :
117+ if (
118+ self .var_name not in frame
119+ or "sel_type" not in frame
120+ or "fparam" not in frame
121+ ):
122+ continue
123+ xas = frame [self .var_name ] # tensor, various shapes
124+ sel_type = frame ["sel_type" ]
125+ fparam = frame ["fparam" ]
126+
127+ # flatten to [nf, task_dim], [nf], [nf, nfparam]
128+ xas = xas .reshape (- 1 , self .task_dim )
129+ sel_type = sel_type .reshape (- 1 ).long ()
130+ fparam = fparam .reshape (- 1 , self .nfparam )
131+ edge_idx = fparam .argmax (dim = - 1 )
132+
133+ nf = xas .shape [0 ]
134+ for i in range (nf ):
135+ t = int (sel_type [i ].item ())
136+ e = int (edge_idx [i ].item ())
137+ if 0 <= t < self .ntypes and 0 <= e < self .nfparam :
138+ accum [(t , e )].append (xas [i , :2 ].detach ().cpu ().numpy ())
139+
140+ if not accum :
141+ log .warning (
142+ "XASLoss.compute_output_stats: no frames with xas+sel_type+fparam found; "
143+ "e_ref remains zero. Training may be unstable."
144+ )
145+ return
146+
147+ e_ref = torch .zeros (
148+ self .ntypes , self .nfparam , 2 , dtype = env .GLOBAL_PT_FLOAT_PRECISION
149+ )
150+ for (t , e ), vals in accum .items ():
151+ e_ref [t , e ] = torch .tensor (
152+ np .mean (vals , axis = 0 ), dtype = env .GLOBAL_PT_FLOAT_PRECISION
153+ )
154+ log .info (
155+ f"XASLoss e_ref: type={ t } , edge={ e } -> "
156+ f"E_min_ref={ float (e_ref [t ,e ,0 ]):.2f} eV, "
157+ f"E_max_ref={ float (e_ref [t ,e ,1 ]):.2f} eV "
158+ f"(n={ len (vals )} )"
159+ )
160+
161+ self .e_ref .copy_ (e_ref )
162+ log .info (
163+ f"XASLoss: e_ref computed for { len (accum )} (sel_type, edge) combinations."
164+ )
165+
166+ if model is not None :
167+ try :
168+ am = model .atomic_model
169+
170+ # 1. Copy e_ref into the model's own buffer so it is saved
171+ # in the checkpoint and available at inference time without
172+ # any external reference file (analogous to out_bias).
173+ if getattr (am , "xas_e_ref" , None ) is not None :
174+ am .xas_e_ref .copy_ (e_ref .to (am .xas_e_ref .dtype ))
175+ log .info ("XASLoss: copied e_ref → model.atomic_model.xas_e_ref." )
176+
177+ # 2. Reset energy-dim out_bias/out_std so the NN predicts
178+ # chemical shifts instead of absolute energies.
179+ #
180+ # Why this is necessary
181+ # ----------------------
182+ # The model stat phase initialises
183+ # out_bias[:, :2] ≈ global_mean(E_min, E_max) ≈ 19 000 eV
184+ # out_std[:, :2] ≈ global_std(E_min, E_max) ≈ 26 000 eV
185+ # so atom_xas[:, 0] = NN_raw[:, 0] * 26 000 + 19 000.
186+ # A single Adam step changes NN_raw by ~lr, which changes
187+ # the physical output by lr × 26 000 = 2.7 eV — the same
188+ # instability as out_bias for energy fitting if the reference
189+ # is wrong. With out_std=1 / out_bias=0, the NN output for
190+ # energy dims is interpreted directly as a chemical shift
191+ # (target ≈ label − e_ref ≈ ±few eV), keeping gradient
192+ # magnitudes O(1) and training stable.
193+ key_idx = am .bias_keys .index (self .var_name )
194+ with torch .no_grad ():
195+ am .out_bias [key_idx , :, :2 ] = 0.0
196+ am .out_std [key_idx , :, :2 ] = 1.0
197+ log .info (
198+ "XASLoss: reset out_bias[:,:2]=0 and out_std[:,:2]=1 "
199+ "for energy dims (model predicts chemical shifts; "
200+ "xas_e_ref restores absolute energies at inference)."
201+ )
202+ except Exception as exc :
203+ log .warning (f"XASLoss: could not update model energy-dim stats: { exc } " )
204+
205+ # ------------------------------------------------------------------
206+ # Forward
207+ # ------------------------------------------------------------------
61208 def forward (
62209 self ,
63210 input_dict : dict [str , torch .Tensor ],
@@ -76,7 +223,7 @@ def forward(
76223 # sel_type from label: [nf, 1] float → [nf] int
77224 sel_type = label ["sel_type" ][:, 0 ].long ()
78225
79- # element-wise mean: for each frame average over atoms of sel_type
226+ # element-wise mean: average atom_prop over atoms of sel_type per frame
80227 nf , nloc , td = atom_prop .shape
81228 pred = torch .zeros (nf , td , dtype = atom_prop .dtype , device = atom_prop .device )
82229 for i in range (nf ):
@@ -87,27 +234,60 @@ def forward(
87234
88235 label_xas = label [self .var_name ] # [nf, task_dim]
89236
237+ # --- per-frame reference energy lookup ---
238+ # edge_idx = argmax of one-hot fparam
239+ fparam = input_dict .get ("fparam" )
240+ if fparam is not None and fparam .numel () > 0 :
241+ edge_idx = fparam .reshape (nf , - 1 ).argmax (dim = - 1 ).clamp (0 , self .nfparam - 1 )
242+ else :
243+ edge_idx = torch .zeros (nf , dtype = torch .long , device = pred .device )
244+
245+ # e_ref_frame: [nf, 2] (E_min_ref, E_max_ref for each frame)
246+ e_ref_frame = self .e_ref [sel_type , edge_idx ] # [nf, 2]
247+
248+ # Shift the energy-dim TARGETS only.
249+ #
250+ # After compute_output_stats has reset out_bias[:,:2]=0 / out_std[:,:2]=1,
251+ # the model outputs raw NN values ≈ 0 for dims 0,1. We train those
252+ # dims against (label − e_ref), i.e. the chemical shift (±few eV),
253+ # keeping gradient magnitudes O(1). Intensity dims (2:) are trained
254+ # against the original label values unchanged.
255+ #
256+ # At inference, we add e_ref back to get the absolute edge energy.
257+ label_shifted = label_xas .clone ()
258+ label_shifted [:, :2 ] = label_xas [:, :2 ] - e_ref_frame
259+
260+ # --- loss ---
90261 loss = torch .zeros (1 , dtype = env .GLOBAL_PT_FLOAT_PRECISION , device = env .DEVICE )[0 ]
91262 if self .loss_func == "smooth_mae" :
92- loss += F .smooth_l1_loss (pred , label_xas , reduction = "sum" , beta = self .beta )
263+ loss += F .smooth_l1_loss (
264+ pred , label_shifted , reduction = "sum" , beta = self .beta
265+ )
93266 elif self .loss_func == "mae" :
94- loss += F .l1_loss (pred , label_xas , reduction = "sum" )
267+ loss += F .l1_loss (pred , label_shifted , reduction = "sum" )
95268 elif self .loss_func == "mse" :
96- loss += F .mse_loss (pred , label_xas , reduction = "sum" )
269+ loss += F .mse_loss (pred , label_shifted , reduction = "sum" )
97270 elif self .loss_func == "rmse" :
98- loss += torch .sqrt (F .mse_loss (pred , label_xas , reduction = "mean" ))
271+ loss += torch .sqrt (F .mse_loss (pred , label_shifted , reduction = "mean" ))
99272 else :
100273 raise RuntimeError (f"Unknown loss function: { self .loss_func } " )
101274
275+ # --- metrics ---
102276 more_loss : dict [str , torch .Tensor ] = {}
103277 if "mae" in self .metric :
104- more_loss ["mae" ] = F .l1_loss (pred , label_xas , reduction = "mean" ).detach ()
278+ more_loss ["mae" ] = F .l1_loss (
279+ pred , label_shifted , reduction = "mean"
280+ ).detach ()
105281 if "rmse" in self .metric :
106282 more_loss ["rmse" ] = torch .sqrt (
107- F .mse_loss (pred , label_xas , reduction = "mean" )
283+ F .mse_loss (pred , label_shifted , reduction = "mean" )
108284 ).detach ()
109285
110- model_pred [self .var_name ] = pred
286+ # Absolute prediction: add e_ref back to energy dims for eval / output
287+ pred_abs = pred .clone ()
288+ pred_abs [:, :2 ] = pred [:, :2 ] + e_ref_frame
289+ model_pred [self .var_name ] = pred_abs
290+
111291 return model_pred , loss , more_loss
112292
113293 @property
0 commit comments