Skip to content

Commit 250168b

Browse files
committed
fix:device
1 parent 8c21612 commit 250168b

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

deepmd/pt/loss/xas.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,9 @@ def forward(
253253
edge_idx = torch.zeros(nf, dtype=torch.long, device=pred.device)
254254

255255
# e_ref_frame: [nf, 2] (E_min_ref, E_max_ref for each frame)
256-
e_ref_frame = self.e_ref[sel_type, edge_idx] # [nf, 2]
256+
# Indices must be on the same device as the buffer (handles CPU/GPU mismatch)
257+
_dev = self.e_ref.device
258+
e_ref_frame = self.e_ref[sel_type.to(_dev), edge_idx.to(_dev)].to(pred.device)
257259

258260
# Shift the energy-dim TARGETS only.
259261
#

0 commit comments

Comments
 (0)