Skip to content

Commit 4f4bd3f

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0e9ae51 commit 4f4bd3f

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

source/tests/pd/common.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def eval_model(
8080
assert isinstance(spins, paddle.Tensor), err_msg
8181
assert isinstance(atom_types, paddle.Tensor) or isinstance(atom_types, list)
8282
if isinstance(atom_types, paddle.Tensor):
83-
atom_types = atom_types.clone().detach().to(dtype=paddle.int32, place=DEVICE)
83+
atom_types = (
84+
atom_types.clone().detach().to(dtype=paddle.int32, place=DEVICE)
85+
)
8486
else:
8587
atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE)
8688
elif isinstance(coords, np.ndarray):
@@ -105,15 +107,27 @@ def eval_model(
105107
natoms = len(atom_types[0])
106108

107109
if isinstance(coords, paddle.Tensor):
108-
coord_input = coords.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
110+
coord_input = (
111+
coords.reshape([-1, natoms, 3])
112+
.clone()
113+
.detach()
114+
.to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
115+
)
109116
else:
110117
coord_input = paddle.to_tensor(
111-
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
118+
coords.reshape([-1, natoms, 3]),
119+
dtype=GLOBAL_PD_FLOAT_PRECISION,
120+
place=DEVICE,
112121
)
113122
spin_input = None
114123
if spins is not None:
115124
if isinstance(spins, paddle.Tensor):
116-
spin_input = spins.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
125+
spin_input = (
126+
spins.reshape([-1, natoms, 3])
127+
.clone()
128+
.detach()
129+
.to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
130+
)
117131
else:
118132
spin_input = paddle.to_tensor(
119133
spins.reshape([-1, natoms, 3]),
@@ -133,7 +147,12 @@ def eval_model(
133147
else:
134148
pbc = True
135149
if isinstance(cells, paddle.Tensor):
136-
box_input = cells.reshape([-1, 3, 3]).clone().detach().to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
150+
box_input = (
151+
cells.reshape([-1, 3, 3])
152+
.clone()
153+
.detach()
154+
.to(dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE)
155+
)
137156
else:
138157
box_input = paddle.to_tensor(
139158
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE

source/tests/pt/common.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def eval_model(
8080
assert isinstance(spins, torch.Tensor), err_msg
8181
assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
8282
if isinstance(atom_types, torch.Tensor):
83-
atom_types = atom_types.clone().detach().to(dtype=torch.int32, device=DEVICE)
83+
atom_types = (
84+
atom_types.clone().detach().to(dtype=torch.int32, device=DEVICE)
85+
)
8486
else:
8587
atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE)
8688
elif isinstance(coords, np.ndarray):
@@ -105,15 +107,27 @@ def eval_model(
105107
natoms = len(atom_types[0])
106108

107109
if isinstance(coords, torch.Tensor):
108-
coord_input = coords.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
110+
coord_input = (
111+
coords.reshape([-1, natoms, 3])
112+
.clone()
113+
.detach()
114+
.to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
115+
)
109116
else:
110117
coord_input = torch.tensor(
111-
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
118+
coords.reshape([-1, natoms, 3]),
119+
dtype=GLOBAL_PT_FLOAT_PRECISION,
120+
device=DEVICE,
112121
)
113122
spin_input = None
114123
if spins is not None:
115124
if isinstance(spins, torch.Tensor):
116-
spin_input = spins.reshape([-1, natoms, 3]).clone().detach().to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
125+
spin_input = (
126+
spins.reshape([-1, natoms, 3])
127+
.clone()
128+
.detach()
129+
.to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
130+
)
117131
else:
118132
spin_input = torch.tensor(
119133
spins.reshape([-1, natoms, 3]),
@@ -133,10 +147,17 @@ def eval_model(
133147
else:
134148
pbc = True
135149
if isinstance(cells, torch.Tensor):
136-
box_input = cells.reshape([-1, 3, 3]).clone().detach().to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
150+
box_input = (
151+
cells.reshape([-1, 3, 3])
152+
.clone()
153+
.detach()
154+
.to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
155+
)
137156
else:
138157
box_input = torch.tensor(
139-
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
158+
cells.reshape([-1, 3, 3]),
159+
dtype=GLOBAL_PT_FLOAT_PRECISION,
160+
device=DEVICE,
140161
)
141162
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)
142163

0 commit comments

Comments
 (0)