Skip to content

Commit 94d2a5a

Browse files
committed
fix: change XAS loss reduction from mean to sum for atomic contributions
1 parent da895d0 commit 94d2a5a

1 file changed

Lines changed: 5 additions & 7 deletions

File tree

deepmd/pt/loss/xas.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -316,14 +316,12 @@ def forward(
316316
# sel_type from label: [nf, 1] float → [nf] int
317317
sel_type = label["sel_type"][:, 0].long()
318318

319-
# element-wise mean: average atom_prop over atoms of sel_type per frame
319+
# Sum atomic contributions over atoms of sel_type per frame.
320+
# The label represents the total XAS spectrum from all sel_type atoms
321+
# in the supercell, so the correct reduction is sum (not mean).
320322
nf, nloc, td = atom_prop.shape
321-
pred = torch.zeros(nf, td, dtype=atom_prop.dtype, device=atom_prop.device)
322-
for i in range(nf):
323-
t = int(sel_type[i].item())
324-
mask = (atype[i] == t).unsqueeze(-1) # [nloc, 1]
325-
count = mask.sum().clamp(min=1)
326-
pred[i] = (atom_prop[i] * mask).sum(dim=0) / count
323+
mask_3d = (atype.unsqueeze(-1) == sel_type.view(nf, 1, 1)) # [nf, nloc, 1]
324+
pred = (atom_prop * mask_3d).sum(dim=1) # [nf, td]
327325

328326
label_xas = label[self.var_name] # [nf, task_dim]
329327

0 commit comments

Comments
 (0)