Skip to content

Commit a5536ab

Browse files
Copilotnjzjz
andcommitted
Remove redundant tensor handling in to_torch_tensor and fix test tensor usage
- Removed lines 247-249 in to_torch_tensor as they were redundant - line 250-251 already handle non-numpy inputs - Fixed TestCalculator and TestCalculatorWithFparamAparam to convert PyTorch tensors to numpy arrays before passing to ASE calculator - This prevents tensor construction warnings by avoiding torch.tensor() calls on existing tensors Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 5eade9b commit a5536ab

2 files changed

Lines changed: 18 additions & 13 deletions

File tree

deepmd/pt/utils/utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,6 @@ def to_torch_tensor(
244244
if xx is None:
245245
return None
246246
assert xx is not None
247-
# Handle PyTorch tensors - clone and move to target device/dtype if needed
248-
if isinstance(xx, torch.Tensor):
249-
return xx.clone().detach().to(device=DEVICE)
250247
if not isinstance(xx, np.ndarray):
251248
return xx
252249
# Create a reverse mapping of NP_PRECISION_DICT

source/tests/pt/test_calculator.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,18 @@ def test_calculator(self) -> None:
6464
atomic_numbers = [1, 1, 1, 8, 8]
6565
idx_perm = [1, 0, 4, 3, 2]
6666

67+
# Convert tensors to numpy for ASE compatibility
68+
cell_np = cell.numpy()
69+
coord_np = coord.numpy()
70+
6771
prec = 1e-10
6872
low_prec = 1e-4
6973

7074
ase_atoms0 = Atoms(
7175
numbers=atomic_numbers,
72-
positions=coord,
76+
positions=coord_np,
7377
# positions=[tuple(item) for item in coordinate],
74-
cell=cell,
78+
cell=cell_np,
7579
calculator=self.calculator,
7680
pbc=True,
7781
)
@@ -83,9 +87,9 @@ def test_calculator(self) -> None:
8387

8488
ase_atoms1 = Atoms(
8589
numbers=[atomic_numbers[i] for i in idx_perm],
86-
positions=coord[idx_perm, :],
90+
positions=coord_np[idx_perm, :],
8791
# positions=[tuple(item) for item in coordinate],
88-
cell=cell,
92+
cell=cell_np,
8993
calculator=self.calculator,
9094
pbc=True,
9195
)
@@ -141,19 +145,23 @@ def test_calculator(self) -> None:
141145
generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED)
142146
coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
143147
coord = torch.matmul(coord, cell)
144-
fparam = torch.IntTensor([1, 2])
145-
aparam = torch.IntTensor([[1], [0], [2], [1], [0]])
148+
fparam = torch.IntTensor([1, 2]).numpy()
149+
aparam = torch.IntTensor([[1], [0], [2], [1], [0]]).numpy()
146150
atomic_numbers = [1, 1, 1, 8, 8]
147151
idx_perm = [1, 0, 4, 3, 2]
148152

153+
# Convert tensors to numpy for ASE compatibility
154+
cell_np = cell.numpy()
155+
coord_np = coord.numpy()
156+
149157
prec = 1e-10
150158
low_prec = 1e-4
151159

152160
ase_atoms0 = Atoms(
153161
numbers=atomic_numbers,
154-
positions=coord,
162+
positions=coord_np,
155163
# positions=[tuple(item) for item in coordinate],
156-
cell=cell,
164+
cell=cell_np,
157165
calculator=self.calculator,
158166
pbc=True,
159167
)
@@ -166,9 +174,9 @@ def test_calculator(self) -> None:
166174

167175
ase_atoms1 = Atoms(
168176
numbers=[atomic_numbers[i] for i in idx_perm],
169-
positions=coord[idx_perm, :],
177+
positions=coord_np[idx_perm, :],
170178
# positions=[tuple(item) for item in coordinate],
171-
cell=cell,
179+
cell=cell_np,
172180
calculator=self.calculator,
173181
pbc=True,
174182
)

0 commit comments

Comments
 (0)