@@ -812,3 +812,181 @@ def test_intensive_vs_legacy_scaling_difference(self) -> None:
812812 places = 5 ,
813813 msg = f"Expected intensive/legacy ratio ~{ expected_ratio :.6f} , got { actual_ratio :.6f} " ,
814814 )
815+
816+
817+ class TestEnerDefaultPf (CommonTest , LossTest , unittest .TestCase ):
818+ """Test energy loss with use_default_pf=True.
819+
820+ The pf term is activated through the default atom_pref of 1.0 even though
821+ `find_atom_pref` is 0.0 in the label. This exercises the cross-backend
822+ consistency between PT and DP for the new option. TF and Paddle backends
823+ raise NotImplementedError when use_default_pf=True and are skipped.
824+ """
825+
826+ @property
827+ def data (self ) -> dict :
828+ return {
829+ "start_pref_e" : 0.02 ,
830+ "limit_pref_e" : 1.0 ,
831+ "start_pref_f" : 1000.0 ,
832+ "limit_pref_f" : 1.0 ,
833+ "start_pref_v" : 1.0 ,
834+ "limit_pref_v" : 1.0 ,
835+ "start_pref_ae" : 1.0 ,
836+ "limit_pref_ae" : 1.0 ,
837+ "start_pref_pf" : 1.0 ,
838+ "limit_pref_pf" : 1.0 ,
839+ "use_default_pf" : True ,
840+ }
841+
842+ skip_tf = True
843+ skip_pd = True
844+ skip_pt = CommonTest .skip_pt
845+ skip_pt_expt = not INSTALLED_PT_EXPT
846+ skip_jax = not INSTALLED_JAX
847+ skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
848+
849+ tf_class = EnerLossTF
850+ dp_class = EnerLossDP
851+ pt_class = EnerLossPT
852+ pt_expt_class = EnerLossPTExpt
853+ jax_class = EnerLossDP
854+ pd_class = EnerLossPD
855+ array_api_strict_class = EnerLossDP
856+ args = loss_ener ()
857+
858+ def setUp (self ) -> None :
859+ CommonTest .setUp (self )
860+ self .learning_rate = 1e-3
861+ rng = np .random .default_rng (20250105 )
862+ self .nframes = 2
863+ self .natoms = 6
864+ self .predict = {
865+ "energy" : rng .random ((self .nframes ,)),
866+ "force" : rng .random ((self .nframes , self .natoms , 3 )),
867+ "virial" : rng .random ((self .nframes , 9 )),
868+ "atom_ener" : rng .random ((self .nframes , self .natoms )),
869+ }
870+ self .predict_dpmodel_style = {
871+ "energy" : self .predict ["energy" ],
872+ "force" : self .predict ["force" ],
873+ "virial" : self .predict ["virial" ],
874+ "atom_energy" : self .predict ["atom_ener" ],
875+ }
876+ # find_atom_pref=0.0 simulates the case where atom_pref.npy is missing;
877+ # use_default_pf=True must override this and still compute the pf loss.
878+ self .label = {
879+ "energy" : rng .random ((self .nframes ,)),
880+ "force" : rng .random ((self .nframes , self .natoms , 3 )),
881+ "virial" : rng .random ((self .nframes , 9 )),
882+ "atom_ener" : rng .random ((self .nframes , self .natoms )),
883+ "atom_pref" : np .ones ((self .nframes , self .natoms , 3 )),
884+ "find_energy" : 1.0 ,
885+ "find_force" : 1.0 ,
886+ "find_virial" : 1.0 ,
887+ "find_atom_ener" : 1.0 ,
888+ "find_atom_pref" : 0.0 ,
889+ }
890+
891+ @property
892+ def additional_data (self ) -> dict :
893+ return {
894+ "starter_learning_rate" : 1e-3 ,
895+ }
896+
897+ def build_tf (self , obj : Any , suffix : str ) -> tuple [list , dict ]:
898+ # use_default_pf=True is not supported by TensorFlow; skip_tf is True so
899+ # this method is never invoked, but the abstract base requires it.
900+ raise NotImplementedError
901+
902+ def eval_pt (self , pt_obj : Any ) -> Any :
903+ predict = {kk : numpy_to_torch (vv ) for kk , vv in self .predict .items ()}
904+ label = {kk : numpy_to_torch (vv ) for kk , vv in self .label .items ()}
905+ predict ["atom_energy" ] = predict .pop ("atom_ener" )
906+ _ , loss , more_loss = pt_obj (
907+ {},
908+ lambda : predict ,
909+ label ,
910+ self .natoms ,
911+ self .learning_rate ,
912+ mae = False ,
913+ )
914+ loss = torch_to_numpy (loss )
915+ more_loss = {kk : torch_to_numpy (vv ) for kk , vv in more_loss .items ()}
916+ return loss , more_loss
917+
918+ def eval_dp (self , dp_obj : Any ) -> Any :
919+ return dp_obj (
920+ self .learning_rate ,
921+ self .natoms ,
922+ self .predict_dpmodel_style ,
923+ self .label ,
924+ mae = False ,
925+ )
926+
927+ def eval_pt_expt (self , pt_expt_obj : Any ) -> Any :
928+ predict = {
929+ kk : numpy_to_torch (vv ) for kk , vv in self .predict_dpmodel_style .items ()
930+ }
931+ label = {kk : numpy_to_torch (vv ) for kk , vv in self .label .items ()}
932+ loss , more_loss = pt_expt_obj (
933+ self .learning_rate ,
934+ self .natoms ,
935+ predict ,
936+ label ,
937+ mae = False ,
938+ )
939+ loss = torch_to_numpy (loss )
940+ more_loss = {kk : torch_to_numpy (vv ) for kk , vv in more_loss .items ()}
941+ return loss , more_loss
942+
943+ def eval_jax (self , jax_obj : Any ) -> Any :
944+ predict = {kk : jnp .asarray (vv ) for kk , vv in self .predict_dpmodel_style .items ()}
945+ label = {kk : jnp .asarray (vv ) for kk , vv in self .label .items ()}
946+ loss , more_loss = jax_obj (
947+ self .learning_rate ,
948+ self .natoms ,
949+ predict ,
950+ label ,
951+ mae = False ,
952+ )
953+ loss = to_numpy_array (loss )
954+ more_loss = {kk : to_numpy_array (vv ) for kk , vv in more_loss .items ()}
955+ return loss , more_loss
956+
957+ def eval_array_api_strict (self , array_api_strict_obj : Any ) -> Any :
958+ predict = {
959+ kk : array_api_strict .asarray (vv )
960+ for kk , vv in self .predict_dpmodel_style .items ()
961+ }
962+ label = {kk : array_api_strict .asarray (vv ) for kk , vv in self .label .items ()}
963+ loss , more_loss = array_api_strict_obj (
964+ self .learning_rate ,
965+ self .natoms ,
966+ predict ,
967+ label ,
968+ mae = False ,
969+ )
970+ loss = to_numpy_array (loss )
971+ more_loss = {kk : to_numpy_array (vv ) for kk , vv in more_loss .items ()}
972+ return loss , more_loss
973+
974+ def extract_ret (self , ret : Any , backend ) -> dict [str , np .ndarray ]:
975+ loss = ret [0 ]
976+ result = {"loss" : np .atleast_1d (np .asarray (loss , dtype = np .float64 ))}
977+ if len (ret ) > 1 :
978+ more_loss = ret [1 ]
979+ for k in sorted (more_loss ):
980+ if k .startswith ("rmse_" ) or k .startswith ("mae_" ):
981+ result [k ] = np .atleast_1d (
982+ np .asarray (more_loss [k ], dtype = np .float64 )
983+ )
984+ return result
985+
986+ @property
987+ def rtol (self ) -> float :
988+ return 1e-10
989+
990+ @property
991+ def atol (self ) -> float :
992+ return 1e-10
0 commit comments