6161 (True , False ), # resnet_dt
6262 ("float64" , "float32" ), # precision
6363 (True , False ), # mixed_types
64+ ([], [0 , 1 ]), # sel_type
6465)
6566class TestDipole (CommonTest , DipoleFittingTest , unittest .TestCase ):
66- @property
67- def data (self ) -> dict :
68- (
69- resnet_dt ,
70- precision ,
71- mixed_types ,
72- ) = self .param
73- return {
74- "neuron" : [5 , 5 , 5 ],
75- "resnet_dt" : resnet_dt ,
76- "precision" : precision ,
77- "seed" : 20240217 ,
78- }
79-
80- @property
81- def skip_pt (self ) -> bool :
82- (
83- resnet_dt ,
84- precision ,
85- mixed_types ,
86- ) = self .param
87- return CommonTest .skip_pt
88-
89- tf_class = DipoleFittingTF
90- dp_class = DipoleFittingDP
91- pt_class = DipoleFittingPT
92- jax_class = DipoleFittingJAX
93- array_api_strict_class = DipoleFittingArrayAPIStrict
94- args = fitting_dipole ()
95- skip_jax = not INSTALLED_JAX
96- skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
97-
98- def setUp (self ) -> None :
99- CommonTest .setUp (self )
100-
101- self .ntypes = 2
102- self .natoms = np .array ([6 , 6 , 2 , 4 ], dtype = np .int32 )
103- self .inputs = np .ones ((1 , 6 , 20 ), dtype = GLOBAL_NP_FLOAT_PRECISION )
104- self .gr = np .ones ((1 , 6 , 30 , 3 ), dtype = GLOBAL_NP_FLOAT_PRECISION )
105- self .atype = np .array ([0 , 1 , 1 , 0 , 1 , 1 ], dtype = np .int32 )
106- # inconsistent if not sorted
107- self .atype .sort ()
108-
109- @property
110- def additional_data (self ) -> dict :
111- (
112- resnet_dt ,
113- precision ,
114- mixed_types ,
115- ) = self .param
116- return {
117- "ntypes" : self .ntypes ,
118- "dim_descrpt" : self .inputs .shape [- 1 ],
119- "mixed_types" : mixed_types ,
120- "embedding_width" : 30 ,
121- }
122-
123- def build_tf (self , obj : Any , suffix : str ) -> tuple [list , dict ]:
124- (
125- resnet_dt ,
126- precision ,
127- mixed_types ,
128- ) = self .param
129- return self .build_tf_fitting (
130- obj ,
131- self .inputs .ravel (),
132- self .gr ,
133- self .natoms ,
134- self .atype ,
135- None ,
136- suffix ,
137- )
138-
139- def eval_pt (self , pt_obj : Any ) -> Any :
140- (
141- resnet_dt ,
142- precision ,
143- mixed_types ,
144- ) = self .param
145- return (
146- pt_obj (
147- torch .from_numpy (self .inputs ).to (device = PT_DEVICE ),
148- torch .from_numpy (self .atype .reshape (1 , - 1 )).to (device = PT_DEVICE ),
149- torch .from_numpy (self .gr ).to (device = PT_DEVICE ),
150- None ,
151- )["dipole" ]
152- .detach ()
153- .cpu ()
154- .numpy ()
155- )
156-
157- def eval_dp (self , dp_obj : Any ) -> Any :
158- (
159- resnet_dt ,
160- precision ,
161- mixed_types ,
162- ) = self .param
163- return dp_obj (
164- self .inputs ,
165- self .atype .reshape (1 , - 1 ),
166- self .gr ,
167- None ,
168- )["dipole" ]
169-
170- def eval_jax (self , jax_obj : Any ) -> Any :
171- return np .asarray (
172- jax_obj (
173- jnp .asarray (self .inputs ),
174- jnp .asarray (self .atype .reshape (1 , - 1 )),
175- jnp .asarray (self .gr ),
176- None ,
177- )["dipole" ]
178- )
179-
180- def eval_array_api_strict (self , array_api_strict_obj : Any ) -> Any :
181- return to_numpy_array (
182- array_api_strict_obj (
183- array_api_strict .asarray (self .inputs ),
184- array_api_strict .asarray (self .atype .reshape (1 , - 1 )),
185- array_api_strict .asarray (self .gr ),
186- None ,
187- )["dipole" ]
188- )
189-
190- def extract_ret (self , ret : Any , backend ) -> tuple [np .ndarray , ...]:
191- if backend == self .RefBackend .TF :
192- # shape is not same
193- ret = ret [0 ].reshape (- 1 , self .natoms [0 ], 1 )
194- return (ret ,)
195-
196- @property
197- def rtol (self ) -> float :
198- """Relative tolerance for comparing the return value."""
199- (
200- resnet_dt ,
201- precision ,
202- mixed_types ,
203- ) = self .param
204- if precision == "float64" :
205- return 1e-10
206- elif precision == "float32" :
207- return 1e-4
208- else :
209- raise ValueError (f"Unknown precision: { precision } " )
210-
211- @property
212- def atol (self ) -> float :
213- """Absolute tolerance for comparing the return value."""
214- (
215- resnet_dt ,
216- precision ,
217- mixed_types ,
218- ) = self .param
219- if precision == "float64" :
220- return 1e-10
221- elif precision == "float32" :
222- return 1e-4
223- else :
224- raise ValueError (f"Unknown precision: { precision } " )
225-
226-
227- @parameterized (
228- (True , False ), # resnet_dt
229- ("float64" , "float32" ), # precision
230- (True , False ), # mixed_types
231- ([0 , 1 ],), # sel_type - only test with all types selected for consistency
232- )
233- class TestDipoleSelType (CommonTest , DipoleFittingTest , unittest .TestCase ):
23467 @property
23568 def data (self ) -> dict :
23669 (
@@ -239,13 +72,16 @@ def data(self) -> dict:
23972 mixed_types ,
24073 sel_type ,
24174 ) = self .param
242- return {
75+ data = {
24376 "neuron" : [5 , 5 , 5 ],
24477 "resnet_dt" : resnet_dt ,
24578 "precision" : precision ,
24679 "seed" : 20240217 ,
247- "sel_type" : sel_type , # For TF backend
24880 }
81+ # Only add sel_type if it's not empty (for TF backend compatibility)
82+ if sel_type :
83+ data ["sel_type" ] = sel_type
84+ return data
24985
25086 @property
25187 def skip_pt (self ) -> bool :
@@ -285,16 +121,18 @@ def additional_data(self) -> dict:
285121 mixed_types ,
286122 sel_type ,
287123 ) = self .param
288- # For DP/PT backends, use exclude_types instead of sel_type
289- all_types = list (range (self .ntypes ))
290- exclude_types = [t for t in all_types if t not in sel_type ]
291- return {
124+ additional = {
292125 "ntypes" : self .ntypes ,
293126 "dim_descrpt" : self .inputs .shape [- 1 ],
294127 "mixed_types" : mixed_types ,
295128 "embedding_width" : 30 ,
296- "exclude_types" : exclude_types , # For DP/PT backends
297129 }
130+ # For DP/PT backends, use exclude_types instead of sel_type
131+ if sel_type :
132+ all_types = list (range (self .ntypes ))
133+ exclude_types = [t for t in all_types if t not in sel_type ]
134+ additional ["exclude_types" ] = exclude_types
135+ return additional
298136
299137 def build_tf (self , obj : Any , suffix : str ) -> tuple [list , dict ]:
300138 (
@@ -404,37 +242,6 @@ def atol(self) -> float:
404242 else :
405243 raise ValueError (f"Unknown precision: { precision } " )
406244
407- def test_sel_type_behavior (self ):
408- """Test that sel_type parameter works correctly across backends."""
409- (
410- resnet_dt ,
411- precision ,
412- mixed_types ,
413- sel_type ,
414- ) = self .param
415-
416- # Test TF backend if available
417- if INSTALLED_TF :
418- tf_obj = self .tf_class (** {** self .data , ** self .additional_data })
419-
420- # Verify that only selected types have fitting nets
421- if hasattr (tf_obj , "sel_type" ):
422- self .assertEqual (set (tf_obj .sel_type ), set (sel_type ))
423- if hasattr (tf_obj , "sel_mask" ):
424- expected_mask = np .array ([i in sel_type for i in range (self .ntypes )])
425- np .testing .assert_array_equal (tf_obj .sel_mask , expected_mask )
426-
427- # Test DP backend
428- all_types = list (range (self .ntypes ))
429- exclude_types = [t for t in all_types if t not in sel_type ]
430- dp_data = {** self .data }
431- dp_data .pop ("sel_type" , None ) # Remove sel_type for DP backend
432- dp_obj = self .dp_class (** {** dp_data , ** self .additional_data })
433-
434- # Verify that exclude_types is set correctly
435- if hasattr (dp_obj , "exclude_types" ):
436- self .assertEqual (set (dp_obj .exclude_types ), set (exclude_types ))
437-
438245
439246class TestDipoleSelTypeBehavior (unittest .TestCase ):
440247 """Test sel_type behavior specifically, without cross-backend consistency."""
0 commit comments