diff --git a/deepmd/pd/model/model/transform_output.py b/deepmd/pd/model/model/transform_output.py index 469bfd3168..ccbeb8e56a 100644 --- a/deepmd/pd/model/model/transform_output.py +++ b/deepmd/pd/model/model/transform_output.py @@ -82,7 +82,9 @@ def task_deriv_one( assert extended_force is not None extended_force = -extended_force if do_virial: - extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2) + extended_virial = paddle.einsum( + "...ik,...ij->...ikj", extended_force, extended_coord + ) # the correction sums to zero, which does not contribute to global virial if do_atomic_virial: extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy) diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index e15eda6a1d..b8f1e024e0 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -85,7 +85,9 @@ def task_deriv_one( assert extended_force is not None extended_force = -extended_force if do_virial: - extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2) + extended_virial = torch.einsum( + "...ik,...ij->...ikj", extended_force, extended_coord + ) # the correction sums to zero, which does not contribute to global virial if do_atomic_virial: extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)