Skip to content

Commit afe7af5

Browse files
Copilotnjzjz
andcommitted
Fix torch.tensor warnings by using clone().detach() for existing tensors
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 7492ef6 commit afe7af5

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

deepmd/pt/utils/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ def to_torch_tensor(
244244
if xx is None:
245245
return None
246246
assert xx is not None
247+
# Handle PyTorch tensors - clone and move to target device/dtype if needed
248+
if isinstance(xx, torch.Tensor):
249+
return xx.clone().detach().to(device=DEVICE)
247250
if not isinstance(xx, np.ndarray):
248251
return xx
249252
# Create a reverse mapping of NP_PRECISION_DICT

source/tests/pt/common.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ def eval_model(
7979
if spins is not None:
8080
assert isinstance(spins, torch.Tensor), err_msg
8181
assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
82-
atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE)
82+
if isinstance(atom_types, torch.Tensor):
83+
atom_types = atom_types.clone().detach().to(dtype=torch.int32, device=DEVICE)
84+
else:
85+
atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE)
8386
elif isinstance(coords, np.ndarray):
8487
if cells is not None:
8588
assert isinstance(cells, np.ndarray), err_msg
@@ -101,28 +104,40 @@ def eval_model(
101104
else:
102105
natoms = len(atom_types[0])
103106

104-
coord_input = torch.tensor(
105-
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
106-
)
107+
if isinstance(coords, torch.Tensor):
108+
coord_input = coords.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
109+
else:
110+
coord_input = torch.tensor(
111+
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
112+
)
107113
spin_input = None
108114
if spins is not None:
109-
spin_input = torch.tensor(
110-
spins.reshape([-1, natoms, 3]),
111-
dtype=GLOBAL_PT_FLOAT_PRECISION,
112-
device=DEVICE,
113-
)
115+
if isinstance(spins, torch.Tensor):
116+
spin_input = spins.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
117+
else:
118+
spin_input = torch.tensor(
119+
spins.reshape([-1, natoms, 3]),
120+
dtype=GLOBAL_PT_FLOAT_PRECISION,
121+
device=DEVICE,
122+
)
114123
has_spin = getattr(model, "has_spin", False)
115124
if callable(has_spin):
116125
has_spin = has_spin()
117-
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
126+
if isinstance(atom_types, torch.Tensor):
127+
type_input = atom_types.clone().detach().to(dtype=torch.long, device=DEVICE)
128+
else:
129+
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
118130
box_input = None
119131
if cells is None:
120132
pbc = False
121133
else:
122134
pbc = True
123-
box_input = torch.tensor(
124-
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
125-
)
135+
if isinstance(cells, torch.Tensor):
136+
box_input = cells.reshape([-1, 3, 3]).clone().detach().to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
137+
else:
138+
box_input = torch.tensor(
139+
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
140+
)
126141
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)
127142

128143
for ii in range(num_iter):

0 commit comments

Comments
 (0)