@@ -132,18 +132,18 @@ def call(
132132 atom_ener_coeff = xp .reshape (atom_ener_coeff , xp .shape (atom_ener ))
133133 energy = xp .sum (atom_ener_coeff * atom_ener , 1 )
134134 if self .has_f or self .has_pf or self .relative_f or self .has_gf :
135- force_reshape = xp .reshape (force , [ - 1 ] )
136- force_hat_reshape = xp .reshape (force_hat , [ - 1 ] )
135+ force_reshape = xp .reshape (force , ( - 1 ,) )
136+ force_hat_reshape = xp .reshape (force_hat , ( - 1 ,) )
137137 diff_f = force_hat_reshape - force_reshape
138138 else :
139139 diff_f = None
140140
141141 if self .relative_f is not None :
142- force_hat_3 = xp .reshape (force_hat , [ - 1 , 3 ] )
143- norm_f = xp .reshape (xp .norm (force_hat_3 , axis = 1 ), [ - 1 , 1 ] ) + self .relative_f
144- diff_f_3 = xp .reshape (diff_f , [ - 1 , 3 ] )
142+ force_hat_3 = xp .reshape (force_hat , ( - 1 , 3 ) )
143+ norm_f = xp .reshape (xp .norm (force_hat_3 , axis = 1 ), ( - 1 , 1 ) ) + self .relative_f
144+ diff_f_3 = xp .reshape (diff_f , ( - 1 , 3 ) )
145145 diff_f_3 = diff_f_3 / norm_f
146- diff_f = xp .reshape (diff_f_3 , [ - 1 ] )
146+ diff_f = xp .reshape (diff_f_3 , ( - 1 ,) )
147147
148148 atom_norm = 1.0 / natoms
149149 atom_norm_ener = 1.0 / natoms
@@ -184,15 +184,15 @@ def call(
184184 loss += pref_f * l2_force_loss
185185 else :
186186 l_huber_loss = custom_huber_loss (
187- xp .reshape (force , [ - 1 ] ),
188- xp .reshape (force_hat , [ - 1 ] ),
187+ xp .reshape (force , ( - 1 ,) ),
188+ xp .reshape (force_hat , ( - 1 ,) ),
189189 delta = self .huber_delta ,
190190 )
191191 loss += pref_f * l_huber_loss
192192 more_loss ["rmse_f" ] = self .display_if_exist (l2_force_loss , find_force )
193193 if self .has_v :
194- virial_reshape = xp .reshape (virial , [ - 1 ] )
195- virial_hat_reshape = xp .reshape (virial_hat , [ - 1 ] )
194+ virial_reshape = xp .reshape (virial , ( - 1 ,) )
195+ virial_hat_reshape = xp .reshape (virial_hat , ( - 1 ,) )
196196 l2_virial_loss = xp .mean (
197197 xp .square (virial_hat_reshape - virial_reshape ),
198198 )
@@ -207,8 +207,8 @@ def call(
207207 loss += pref_v * l_huber_loss
208208 more_loss ["rmse_v" ] = self .display_if_exist (l2_virial_loss , find_virial )
209209 if self .has_ae :
210- atom_ener_reshape = xp .reshape (atom_ener , [ - 1 ] )
211- atom_ener_hat_reshape = xp .reshape (atom_ener_hat , [ - 1 ] )
210+ atom_ener_reshape = xp .reshape (atom_ener , ( - 1 ,) )
211+ atom_ener_hat_reshape = xp .reshape (atom_ener_hat , ( - 1 ,) )
212212 l2_atom_ener_loss = xp .mean (
213213 xp .square (atom_ener_hat_reshape - atom_ener_reshape ),
214214 )
@@ -225,7 +225,7 @@ def call(
225225 l2_atom_ener_loss , find_atom_ener
226226 )
227227 if self .has_pf :
228- atom_pref_reshape = xp .reshape (atom_pref , [ - 1 ] )
228+ atom_pref_reshape = xp .reshape (atom_pref , ( - 1 ,) )
229229 l2_pref_force_loss = xp .mean (
230230 xp .multiply (xp .square (diff_f ), atom_pref_reshape ),
231231 )
@@ -236,10 +236,10 @@ def call(
236236 if self .has_gf :
237237 find_drdq = label_dict ["find_drdq" ]
238238 drdq = label_dict ["drdq" ]
239- force_reshape_nframes = xp .reshape (force , [ - 1 , natoms [0 ] * 3 ] )
240- force_hat_reshape_nframes = xp .reshape (force_hat , [ - 1 , natoms [0 ] * 3 ] )
239+ force_reshape_nframes = xp .reshape (force , ( - 1 , natoms [0 ] * 3 ) )
240+ force_hat_reshape_nframes = xp .reshape (force_hat , ( - 1 , natoms [0 ] * 3 ) )
241241 drdq_reshape = xp .reshape (
242- drdq , [ - 1 , natoms [0 ] * 3 , self .numb_generalized_coord ]
242+ drdq , ( - 1 , natoms [0 ] * 3 , self .numb_generalized_coord )
243243 )
244244 gen_force_hat = xp .einsum (
245245 "bij,bi->bj" , drdq_reshape , force_hat_reshape_nframes
0 commit comments