We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 8c21612 commit 250168bCopy full SHA for 250168b
1 file changed
deepmd/pt/loss/xas.py
@@ -253,7 +253,9 @@ def forward(
253
edge_idx = torch.zeros(nf, dtype=torch.long, device=pred.device)
254
255
# e_ref_frame: [nf, 2] (E_min_ref, E_max_ref for each frame)
256
- e_ref_frame = self.e_ref[sel_type, edge_idx] # [nf, 2]
+ # Indices must be on the same device as the buffer (handles CPU/GPU mismatch)
257
+ _dev = self.e_ref.device
258
+ e_ref_frame = self.e_ref[sel_type.to(_dev), edge_idx.to(_dev)].to(pred.device)
259
260
# Shift the energy-dim TARGETS only.
261
#
0 commit comments