@@ -44,7 +44,6 @@ class DipoleChargeModifier(BaseModifier):
4444 Splitting parameter of the Ewald sum. Unit: A^{-1}
4545 """
4646
47-
4847 def __init__ (
4948 self ,
5049 model_name : str | None ,
@@ -224,7 +223,9 @@ def forward(
224223 chunk_charge = torch .split (
225224 extended_charge .reshape (nframes , - 1 ), self .dp_batch_size , dim = 0
226225 )
227- for _coord , _box , _charge in zip (chunk_coord , chunk_box , chunk_charge , strict = True ):
226+ for _coord , _box , _charge in zip (
227+ chunk_coord , chunk_box , chunk_charge , strict = True
228+ ):
228229 self .er (
229230 _coord ,
230231 _box ,
@@ -244,9 +245,7 @@ def forward(
244245 tot_v = calc_grads (tot_e , input_box )
245246 tot_v = torch .reshape (tot_v , (nframes , 3 , 3 ))
246247 # nframe, 3, 3
247- tot_v = - torch .matmul (
248- tot_v .transpose (2 , 1 ), input_box .reshape (nframes , 3 , 3 )
249- )
248+ tot_v = - torch .matmul (tot_v .transpose (2 , 1 ), input_box .reshape (nframes , 3 , 3 ))
250249
251250 modifier_pred ["energy" ] = tot_e
252251 modifier_pred ["force" ] = tot_f
0 commit comments