diff --git a/deepmd/pd/model/model/transform_output.py b/deepmd/pd/model/model/transform_output.py index ccbeb8e56a..469bfd3168 100644 --- a/deepmd/pd/model/model/transform_output.py +++ b/deepmd/pd/model/model/transform_output.py @@ -82,9 +82,7 @@ def task_deriv_one( assert extended_force is not None extended_force = -extended_force if do_virial: - extended_virial = paddle.einsum( - "...ik,...ij->...ikj", extended_force, extended_coord - ) + extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2) # the correction sums to zero, which does not contribute to global virial if do_atomic_virial: extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)