Skip to content

Commit 623269f

Browse files
committed
fix ut
1 parent fbabf9d commit 623269f

3 files changed

Lines changed: 20 additions & 7 deletions

File tree

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,12 @@ def forward_atomic(
184184
nframes, nloc, nnei = nlist.shape
185185
atype = xp_take_first_n(extended_atype, 1, nloc)
186186

187-
if self.fitting_net.get_dim_fparam() > 0 and fparam is None:
187+
# Handle default fparam if fitting net supports it
188+
if (
189+
hasattr(self.fitting_net, "get_dim_fparam")
190+
and self.fitting_net.get_dim_fparam() > 0
191+
and fparam is None
192+
):
188193
# use default fparam
189194
from deepmd.dpmodel.array_api import (
190195
array_api_compat,
@@ -193,8 +198,11 @@ def forward_atomic(
193198
default_fparam = self.fitting_net.get_default_fparam()
194199
assert default_fparam is not None
195200
xp = array_api_compat.array_namespace(extended_coord)
201+
default_fparam_array = xp.asarray(
202+
default_fparam, dtype=extended_coord.dtype
203+
)
196204
fparam_input_for_des = xp.tile(
197-
xp.reshape(default_fparam, (1, -1)), (nframes, 1)
205+
xp.reshape(default_fparam_array, (1, -1)), (nframes, 1)
198206
)
199207
else:
200208
fparam_input_for_des = fparam

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,12 @@ def forward_atomic(
274274
if self.do_grad_r() or self.do_grad_c():
275275
extended_coord.requires_grad_(True)
276276

277-
if self.fitting_net.get_dim_fparam() > 0 and fparam is None:
277+
# Handle default fparam if fitting net supports it
278+
if (
279+
hasattr(self.fitting_net, "get_dim_fparam")
280+
and self.fitting_net.get_dim_fparam() > 0
281+
and fparam is None
282+
):
278283
# use default fparam
279284
default_fparam_tensor = self.fitting_net.get_default_fparam()
280285
assert default_fparam_tensor is not None

source/tests/consistent/descriptor/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,13 +246,13 @@ def eval_pd_descriptor(
246246
pd_obj.get_sel(),
247247
distinguish_types=(not mixed_types),
248248
)
249-
fparam_pd = (
250-
paddle.to_tensor(fparam).to(PD_DEVICE) if fparam is not None else None
251-
)
252249
return [
253250
x.detach().cpu().numpy() if paddle.is_tensor(x) else x
254251
for x in pd_obj(
255-
ext_coords, ext_atype, nlist=nlist, mapping=mapping, fparam=fparam_pd
252+
ext_coords,
253+
ext_atype,
254+
nlist=nlist,
255+
mapping=mapping,
256256
)
257257
]
258258

0 commit comments

Comments
 (0)