Skip to content

Commit 5eade9b

Browse files
Copilotnjzjz
andcommitted
Fix PaddlePaddle tensor construction warnings and enhance PyTorch to_torch_tensor function
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent afe7af5 commit 5eade9b

2 files changed

Lines changed: 31 additions & 13 deletions

File tree

deepmd/pd/utils/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,9 @@ def to_paddle_tensor(
267267
if xx is None:
268268
return None
269269
assert xx is not None
270+
# Handle PaddlePaddle tensors - clone and move to target device/dtype if needed
271+
if isinstance(xx, paddle.Tensor):
272+
return xx.clone().detach().to(device=DEVICE)
270273
if not isinstance(xx, np.ndarray):
271274
return xx
272275
# Create a reverse mapping of NP_PRECISION_DICT

source/tests/pd/common.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ def eval_model(
7979
if spins is not None:
8080
assert isinstance(spins, paddle.Tensor), err_msg
8181
assert isinstance(atom_types, paddle.Tensor) or isinstance(atom_types, list)
82-
atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE)
82+
if isinstance(atom_types, paddle.Tensor):
83+
atom_types = atom_types.clone().detach().to(dtype=paddle.int32, place=DEVICE)
84+
else:
85+
atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE)
8386
elif isinstance(coords, np.ndarray):
8487
if cells is not None:
8588
assert isinstance(cells, np.ndarray), err_msg
@@ -101,28 +104,40 @@ def eval_model(
101104
else:
102105
natoms = len(atom_types[0])
103106

104-
coord_input = paddle.to_tensor(
105-
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
106-
)
107+
if isinstance(coords, paddle.Tensor):
108+
coord_input = coords.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
109+
else:
110+
coord_input = paddle.to_tensor(
111+
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
112+
)
107113
spin_input = None
108114
if spins is not None:
109-
spin_input = paddle.to_tensor(
110-
spins.reshape([-1, natoms, 3]),
111-
dtype=GLOBAL_PD_FLOAT_PRECISION,
112-
place=DEVICE,
113-
)
115+
if isinstance(spins, paddle.Tensor):
116+
spin_input = spins.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
117+
else:
118+
spin_input = paddle.to_tensor(
119+
spins.reshape([-1, natoms, 3]),
120+
dtype=GLOBAL_PD_FLOAT_PRECISION,
121+
place=DEVICE,
122+
)
114123
has_spin = getattr(model, "has_spin", False)
115124
if callable(has_spin):
116125
has_spin = has_spin()
117-
type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE)
126+
if isinstance(atom_types, paddle.Tensor):
127+
type_input = atom_types.clone().detach().to(dtype=paddle.int64, place=DEVICE)
128+
else:
129+
type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE)
118130
box_input = None
119131
if cells is None:
120132
pbc = False
121133
else:
122134
pbc = True
123-
box_input = paddle.to_tensor(
124-
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
125-
)
135+
if isinstance(cells, paddle.Tensor):
136+
box_input = cells.reshape([-1, 3, 3]).clone().detach().to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
137+
else:
138+
box_input = paddle.to_tensor(
139+
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
140+
)
126141
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)
127142

128143
for ii in range(num_iter):

0 commit comments

Comments
 (0)