6161 (True , False ), # resnet_dt
6262 ("float64" , "float32" ), # precision
6363 (True , False ), # mixed_types
64- ([] , [0 ]), # sel_type
64+ (None , [0 ]), # sel_type
6565)
6666class TestDipole (CommonTest , DipoleFittingTest , unittest .TestCase ):
6767 @property
@@ -76,13 +76,32 @@ def data(self) -> dict:
7676 "neuron" : [5 , 5 , 5 ],
7777 "resnet_dt" : resnet_dt ,
7878 "precision" : precision ,
79+ "sel_type" : sel_type ,
7980 "seed" : 20240217 ,
8081 }
81- # Only add sel_type if it's not empty (for TF backend compatibility)
82- if sel_type :
83- data ["sel_type" ] = sel_type
8482 return data
8583
84+ def pass_data_to_cls (self , cls , data ) -> Any :
85+ """Pass data to the class."""
86+ if cls not in (self .tf_class ,):
87+ sel_type = data .pop ("sel_type" , None )
88+ if sel_type is not None :
89+ all_types = list (range (self .ntypes ))
90+ exclude_types = [t for t in all_types if t not in sel_type ]
91+ data ["exclude_types" ] = exclude_types
92+ return cls (** data , ** self .additional_data )
93+
94+ @property
95+ def skip_tf (self ) -> bool :
96+ (
97+ resnet_dt ,
98+ precision ,
99+ mixed_types ,
100+ sel_type ,
101+ ) = self .param
102+ # mixed_types + sel_type is not supported
103+ return CommonTest .skip_tf or (mixed_types and sel_type is not None )
104+
86105 @property
87106 def skip_pt (self ) -> bool :
88107 (
@@ -127,11 +146,6 @@ def additional_data(self) -> dict:
127146 "mixed_types" : mixed_types ,
128147 "embedding_width" : 30 ,
129148 }
130- # For DP/PT backends, use exclude_types instead of sel_type
131- if sel_type :
132- all_types = list (range (self .ntypes ))
133- exclude_types = [t for t in all_types if t not in sel_type ]
134- additional ["exclude_types" ] = exclude_types
135149 return additional
136150
137151 def build_tf (self , obj : Any , suffix : str ) -> tuple [list , dict ]:
@@ -241,3 +255,39 @@ def atol(self) -> float:
241255 return 1e-4
242256 else :
243257 raise ValueError (f"Unknown precision: { precision } " )
258+
259+ def test_tf_consistent_with_ref (self ) -> None :
260+ """Test whether TF and reference are consistent."""
261+ # Special handle for sel_types
262+ if self .skip_tf :
263+ self .skipTest ("Unsupported backend" )
264+ ref_backend = self .get_reference_backend ()
265+ if ref_backend == self .RefBackend .TF :
266+ self .skipTest ("Reference is self" )
267+ ret1 , data1 = self .get_reference_ret_serialization (ref_backend )
268+ ret1 = self .extract_ret (ret1 , ref_backend )
269+ self .reset_unique_id ()
270+ tf_obj = self .tf_class .deserialize (data1 , suffix = self .unique_id )
271+ ret2 , data2 = self .get_tf_ret_serialization_from_cls (tf_obj )
272+ ret2 = self .extract_ret (ret2 , self .RefBackend .TF )
273+ if tf_obj .__class__ .__name__ .startswith (("Polar" , "Dipole" , "DOS" )):
274+ # tf, pt serialization mismatch
275+ common_keys = set (data1 .keys ()) & set (data2 .keys ())
276+ data1 = {k : data1 [k ] for k in common_keys }
277+ data2 = {k : data2 [k ] for k in common_keys }
278+
279+ # not comparing version
280+ data1 .pop ("@version" )
281+ data2 .pop ("@version" )
282+
283+ if tf_obj .__class__ .__name__ .startswith ("Polar" ):
284+ data1 ["@variables" ].pop ("bias_atom_e" )
285+ for ii , networks in enumerate (data2 ["nets" ]["networks" ]):
286+ if networks is None :
287+ data1 ["nets" ]["networks" ][ii ] = None
288+ np .testing .assert_equal (data1 , data2 )
289+ for rr1 , rr2 in zip (ret1 , ret2 ):
290+ np .testing .assert_allclose (
291+ rr1 .ravel ()[: rr2 .size ], rr2 .ravel (), rtol = self .rtol , atol = self .atol
292+ )
293+ assert rr1 .dtype == rr2 .dtype , f"{ rr1 .dtype } != { rr2 .dtype } "
0 commit comments