Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deepmd/pd/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def task_deriv_one(
assert extended_force is not None
extended_force = -extended_force
if do_virial:
extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
extended_virial = paddle.einsum(
"...ik,...ij->...ikj", extended_force, extended_coord
)
# the correction sums to zero, which does not contribute to global virial
if do_atomic_virial:
extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def task_deriv_one(
assert extended_force is not None
extended_force = -extended_force
if do_virial:
Comment thread
caic99 marked this conversation as resolved.
extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
extended_virial = torch.einsum("...ik,...ij->...ikj", extended_force, extended_coord)
# the correction sums to zero, which does not contribute to global virial
if do_atomic_virial:
extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)
Expand Down
Loading