Skip to content

Commit b4083f8

Browse files
committed
Update esen_model.py
1 parent 22e1ec5 commit b4083f8

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

deepmd/pt/model/model/esen_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def forward(
343343
# use default fparam
344344
assert self.default_fparam_tensor is not None
345345
fparam = torch.tile(self.default_fparam_tensor.unsqueeze(0), [nf, 1])
346-
atomic_number = self.atype_to_idx[atype.view(-1)].view(nf, nloc)
346+
atomic_number = self.atype_to_idx[atype.view(-1)].view(nf, nloc) + 1
347347
tags = torch.zeros_like(atomic_number)
348348
pbc = torch.ones([3], dtype=torch.bool, device=coord.device)
349349
fixed_idx = torch.zeros_like(atomic_number)
@@ -368,12 +368,12 @@ def forward(
368368
model_ret = self.model(batch)
369369

370370
# apply energy bias
371-
model_ret["energy"] = model_ret["energy"] + self.out_bias[0, :, 0][atype].sum(-1)
371+
model_ret["energy"] = model_ret["energy"] * 0.8125 + self.out_bias[0, :, 0][atype].sum(-1)
372372

373373
model_predict = {}
374374
model_predict["energy"] = model_ret["energy"]
375-
model_predict["force"] = model_ret["forces"].view(nf, nloc, 3)
376-
model_predict["virial"] = model_ret["virial"].view(nf, 9)
375+
model_predict["force"] = model_ret["forces"].view(nf, nloc, 3)*0.8125
376+
model_predict["virial"] = model_ret["virial"].view(nf, 9)*0.8125
377377
return model_predict
378378

379379
@torch.jit.export

0 commit comments

Comments
 (0)