Skip to content

Commit 0b13512

Browse files
committed
update add_chg_spin_ebd for default fparam
1 parent d0c042f commit 0b13512

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,24 @@ def forward_atomic(
253253
atype = extended_atype[:, :nloc]
254254
if self.do_grad_r() or self.do_grad_c():
255255
extended_coord.requires_grad_(True)
256+
257+
if self.fitting_net.get_dim_fparam() > 0 and fparam is None:
258+
# use default fparam
259+
default_fparam_tensor = self.fitting_net.get_default_fparam()
260+
assert default_fparam_tensor is not None
261+
fparam_input_for_des = torch.tile(
262+
default_fparam_tensor.unsqueeze(0), [nframes, 1]
263+
)
264+
else:
265+
fparam_input_for_des = fparam
266+
256267
descriptor, rot_mat, g2, h2, sw = self.descriptor(
257268
extended_coord,
258269
extended_atype,
259270
nlist,
260271
mapping=mapping,
261272
comm_dict=comm_dict,
262-
fparam=fparam,
273+
fparam=fparam_input_for_des,
263274
)
264275
assert descriptor is not None
265276
if self.enable_eval_descriptor_hook:

0 commit comments

Comments
 (0)