Skip to content
Merged
3 changes: 3 additions & 0 deletions deepmd/pd/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ def to_paddle_tensor(
if xx is None:
return None
assert xx is not None
# Handle PaddlePaddle tensors - clone and move to target device/dtype if needed
if isinstance(xx, paddle.Tensor):
return xx.clone().detach().to(device=DEVICE)
if not isinstance(xx, np.ndarray):
return xx
# Create a reverse mapping of NP_PRECISION_DICT
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ def to_torch_tensor(
if xx is None:
return None
assert xx is not None
# Handle PyTorch tensors - clone and move to target device/dtype if needed
if isinstance(xx, torch.Tensor):
return xx.clone().detach().to(device=DEVICE)
Comment thread
njzjz marked this conversation as resolved.
Outdated
if not isinstance(xx, np.ndarray):
return xx
# Create a reverse mapping of NP_PRECISION_DICT
Expand Down
41 changes: 28 additions & 13 deletions source/tests/pd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def eval_model(
if spins is not None:
assert isinstance(spins, paddle.Tensor), err_msg
assert isinstance(atom_types, paddle.Tensor) or isinstance(atom_types, list)
atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE)
if isinstance(atom_types, paddle.Tensor):
atom_types = atom_types.clone().detach().to(dtype=paddle.int32, place=DEVICE)
else:
atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE)
elif isinstance(coords, np.ndarray):
if cells is not None:
assert isinstance(cells, np.ndarray), err_msg
Expand All @@ -101,28 +104,40 @@ def eval_model(
else:
natoms = len(atom_types[0])

coord_input = paddle.to_tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
)
if isinstance(coords, paddle.Tensor):
coord_input = coords.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
else:
coord_input = paddle.to_tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
)
spin_input = None
if spins is not None:
spin_input = paddle.to_tensor(
spins.reshape([-1, natoms, 3]),
dtype=GLOBAL_PD_FLOAT_PRECISION,
place=DEVICE,
)
if isinstance(spins, paddle.Tensor):
spin_input = spins.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
else:
spin_input = paddle.to_tensor(
spins.reshape([-1, natoms, 3]),
dtype=GLOBAL_PD_FLOAT_PRECISION,
place=DEVICE,
)
has_spin = getattr(model, "has_spin", False)
if callable(has_spin):
has_spin = has_spin()
type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE)
if isinstance(atom_types, paddle.Tensor):
type_input = atom_types.clone().detach().to(dtype=paddle.int64, place=DEVICE)
else:
type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE)
box_input = None
if cells is None:
pbc = False
else:
pbc = True
box_input = paddle.to_tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
)
if isinstance(cells, paddle.Tensor):
box_input = cells.reshape([-1, 3, 3]).clone().detach().to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
else:
box_input = paddle.to_tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
)
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)

for ii in range(num_iter):
Expand Down
41 changes: 28 additions & 13 deletions source/tests/pt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def eval_model(
if spins is not None:
assert isinstance(spins, torch.Tensor), err_msg
assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE)
if isinstance(atom_types, torch.Tensor):
atom_types = atom_types.clone().detach().to(dtype=torch.int32, device=DEVICE)
else:
atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE)
elif isinstance(coords, np.ndarray):
if cells is not None:
assert isinstance(cells, np.ndarray), err_msg
Expand All @@ -101,28 +104,40 @@ def eval_model(
else:
natoms = len(atom_types[0])

coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
if isinstance(coords, torch.Tensor):
coord_input = coords.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
else:
coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
spin_input = None
if spins is not None:
spin_input = torch.tensor(
spins.reshape([-1, natoms, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
if isinstance(spins, torch.Tensor):
spin_input = spins.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
else:
spin_input = torch.tensor(
spins.reshape([-1, natoms, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
has_spin = getattr(model, "has_spin", False)
if callable(has_spin):
has_spin = has_spin()
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
if isinstance(atom_types, torch.Tensor):
type_input = atom_types.clone().detach().to(dtype=torch.long, device=DEVICE)
else:
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
box_input = None
if cells is None:
pbc = False
else:
pbc = True
box_input = torch.tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
if isinstance(cells, torch.Tensor):
box_input = cells.reshape([-1, 3, 3]).clone().detach().to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
else:
box_input = torch.tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)

for ii in range(num_iter):
Expand Down