@@ -222,3 +222,331 @@ def atol(self) -> float:
222222 return 1e-4
223223 else :
224224 raise ValueError (f"Unknown precision: { precision } " )
225+
226+
227+ @parameterized (
228+ (True , False ), # resnet_dt
229+ ("float64" , "float32" ), # precision
230+ (True , False ), # mixed_types
231+ ([0 , 1 ],), # sel_type - only test with all types selected for consistency
232+ )
233+ class TestDipoleSelType (CommonTest , DipoleFittingTest , unittest .TestCase ):
234+ @property
235+ def data (self ) -> dict :
236+ (
237+ resnet_dt ,
238+ precision ,
239+ mixed_types ,
240+ sel_type ,
241+ ) = self .param
242+ return {
243+ "neuron" : [5 , 5 , 5 ],
244+ "resnet_dt" : resnet_dt ,
245+ "precision" : precision ,
246+ "seed" : 20240217 ,
247+ "sel_type" : sel_type , # For TF backend
248+ }
249+
250+ @property
251+ def skip_pt (self ) -> bool :
252+ (
253+ resnet_dt ,
254+ precision ,
255+ mixed_types ,
256+ sel_type ,
257+ ) = self .param
258+ return CommonTest .skip_pt
259+
260+ tf_class = DipoleFittingTF
261+ dp_class = DipoleFittingDP
262+ pt_class = DipoleFittingPT
263+ jax_class = DipoleFittingJAX
264+ array_api_strict_class = DipoleFittingArrayAPIStrict
265+ args = fitting_dipole ()
266+ skip_jax = not INSTALLED_JAX
267+ skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
268+
269+ def setUp (self ) -> None :
270+ CommonTest .setUp (self )
271+
272+ self .ntypes = 2
273+ self .natoms = np .array ([6 , 6 , 2 , 4 ], dtype = np .int32 )
274+ self .inputs = np .ones ((1 , 6 , 20 ), dtype = GLOBAL_NP_FLOAT_PRECISION )
275+ self .gr = np .ones ((1 , 6 , 30 , 3 ), dtype = GLOBAL_NP_FLOAT_PRECISION )
276+ self .atype = np .array ([0 , 1 , 1 , 0 , 1 , 1 ], dtype = np .int32 )
277+ # inconsistent if not sorted
278+ self .atype .sort ()
279+
280+ @property
281+ def additional_data (self ) -> dict :
282+ (
283+ resnet_dt ,
284+ precision ,
285+ mixed_types ,
286+ sel_type ,
287+ ) = self .param
288+ # For DP/PT backends, use exclude_types instead of sel_type
289+ all_types = list (range (self .ntypes ))
290+ exclude_types = [t for t in all_types if t not in sel_type ]
291+ return {
292+ "ntypes" : self .ntypes ,
293+ "dim_descrpt" : self .inputs .shape [- 1 ],
294+ "mixed_types" : mixed_types ,
295+ "embedding_width" : 30 ,
296+ "exclude_types" : exclude_types , # For DP/PT backends
297+ }
298+
299+ def build_tf (self , obj : Any , suffix : str ) -> tuple [list , dict ]:
300+ (
301+ resnet_dt ,
302+ precision ,
303+ mixed_types ,
304+ sel_type ,
305+ ) = self .param
306+ return self .build_tf_fitting (
307+ obj ,
308+ self .inputs .ravel (),
309+ self .gr ,
310+ self .natoms ,
311+ self .atype ,
312+ None ,
313+ suffix ,
314+ )
315+
316+ def eval_pt (self , pt_obj : Any ) -> Any :
317+ (
318+ resnet_dt ,
319+ precision ,
320+ mixed_types ,
321+ sel_type ,
322+ ) = self .param
323+ return (
324+ pt_obj (
325+ torch .from_numpy (self .inputs ).to (device = PT_DEVICE ),
326+ torch .from_numpy (self .atype .reshape (1 , - 1 )).to (device = PT_DEVICE ),
327+ torch .from_numpy (self .gr ).to (device = PT_DEVICE ),
328+ None ,
329+ )["dipole" ]
330+ .detach ()
331+ .cpu ()
332+ .numpy ()
333+ )
334+
335+ def eval_dp (self , dp_obj : Any ) -> Any :
336+ (
337+ resnet_dt ,
338+ precision ,
339+ mixed_types ,
340+ sel_type ,
341+ ) = self .param
342+ return dp_obj (
343+ self .inputs ,
344+ self .atype .reshape (1 , - 1 ),
345+ self .gr ,
346+ None ,
347+ )["dipole" ]
348+
349+ def eval_jax (self , jax_obj : Any ) -> Any :
350+ return np .asarray (
351+ jax_obj (
352+ jnp .asarray (self .inputs ),
353+ jnp .asarray (self .atype .reshape (1 , - 1 )),
354+ jnp .asarray (self .gr ),
355+ None ,
356+ )["dipole" ]
357+ )
358+
359+ def eval_array_api_strict (self , array_api_strict_obj : Any ) -> Any :
360+ return to_numpy_array (
361+ array_api_strict_obj (
362+ array_api_strict .asarray (self .inputs ),
363+ array_api_strict .asarray (self .atype .reshape (1 , - 1 )),
364+ array_api_strict .asarray (self .gr ),
365+ None ,
366+ )["dipole" ]
367+ )
368+
369+ def extract_ret (self , ret : Any , backend ) -> tuple [np .ndarray , ...]:
370+ if backend == self .RefBackend .TF :
371+ # shape is not same
372+ ret = ret [0 ].reshape (- 1 , self .natoms [0 ], 1 )
373+ return (ret ,)
374+
375+ @property
376+ def rtol (self ) -> float :
377+ """Relative tolerance for comparing the return value."""
378+ (
379+ resnet_dt ,
380+ precision ,
381+ mixed_types ,
382+ sel_type ,
383+ ) = self .param
384+ if precision == "float64" :
385+ return 1e-10
386+ elif precision == "float32" :
387+ return 1e-4
388+ else :
389+ raise ValueError (f"Unknown precision: { precision } " )
390+
391+ @property
392+ def atol (self ) -> float :
393+ """Absolute tolerance for comparing the return value."""
394+ (
395+ resnet_dt ,
396+ precision ,
397+ mixed_types ,
398+ sel_type ,
399+ ) = self .param
400+ if precision == "float64" :
401+ return 1e-10
402+ elif precision == "float32" :
403+ return 1e-4
404+ else :
405+ raise ValueError (f"Unknown precision: { precision } " )
406+
407+ def test_sel_type_behavior (self ):
408+ """Test that sel_type parameter works correctly across backends."""
409+ (
410+ resnet_dt ,
411+ precision ,
412+ mixed_types ,
413+ sel_type ,
414+ ) = self .param
415+
416+ # Test TF backend if available
417+ if INSTALLED_TF :
418+ tf_obj = self .tf_class (** {** self .data , ** self .additional_data })
419+
420+ # Verify that only selected types have fitting nets
421+ if hasattr (tf_obj , "sel_type" ):
422+ self .assertEqual (set (tf_obj .sel_type ), set (sel_type ))
423+ if hasattr (tf_obj , "sel_mask" ):
424+ expected_mask = np .array ([i in sel_type for i in range (self .ntypes )])
425+ np .testing .assert_array_equal (tf_obj .sel_mask , expected_mask )
426+
427+ # Test DP backend
428+ all_types = list (range (self .ntypes ))
429+ exclude_types = [t for t in all_types if t not in sel_type ]
430+ dp_data = {** self .data }
431+ dp_data .pop ("sel_type" , None ) # Remove sel_type for DP backend
432+ dp_obj = self .dp_class (** {** dp_data , ** self .additional_data })
433+
434+ # Verify that exclude_types is set correctly
435+ if hasattr (dp_obj , "exclude_types" ):
436+ self .assertEqual (set (dp_obj .exclude_types ), set (exclude_types ))
437+
438+
439+ class TestDipoleSelTypeBehavior (unittest .TestCase ):
440+ """Test sel_type behavior specifically, without cross-backend consistency."""
441+
442+ def setUp (self ) -> None :
443+ self .ntypes = 2
444+ self .natoms = np .array ([6 , 6 , 2 , 4 ], dtype = np .int32 )
445+
446+ def test_tf_sel_type_all_types (self ):
447+ """Test that TF dipole fitting creates networks for all selected types."""
448+ if not INSTALLED_TF :
449+ self .skipTest ("TensorFlow not available" )
450+
451+ sel_type = [0 , 1 ] # Select all types
452+
453+ tf_obj = DipoleFittingTF (
454+ ntypes = self .ntypes ,
455+ dim_descrpt = 20 ,
456+ embedding_width = 30 ,
457+ neuron = [5 , 5 , 5 ],
458+ sel_type = sel_type ,
459+ )
460+
461+ # Verify sel_type is set correctly
462+ self .assertEqual (set (tf_obj .sel_type ), set (sel_type ))
463+
464+ # Verify sel_mask is correct
465+ expected_mask = np .array ([i in sel_type for i in range (self .ntypes )])
466+ np .testing .assert_array_equal (tf_obj .sel_mask , expected_mask )
467+
468+ def test_tf_sel_type_partial (self ):
469+ """Test that TF dipole fitting works with partial type selection."""
470+ if not INSTALLED_TF :
471+ self .skipTest ("TensorFlow not available" )
472+
473+ sel_type = [0 ] # Select only type 0
474+
475+ tf_obj = DipoleFittingTF (
476+ ntypes = self .ntypes ,
477+ dim_descrpt = 20 ,
478+ embedding_width = 30 ,
479+ neuron = [5 , 5 , 5 ],
480+ sel_type = sel_type ,
481+ )
482+
483+ # Verify sel_type is set correctly
484+ self .assertEqual (set (tf_obj .sel_type ), set (sel_type ))
485+
486+ # Verify sel_mask is correct
487+ expected_mask = np .array ([i in sel_type for i in range (self .ntypes )])
488+ np .testing .assert_array_equal (tf_obj .sel_mask , expected_mask )
489+
490+ def test_dp_exclude_types_behavior (self ):
491+ """Test that DP dipole fitting excludes the correct types."""
492+ sel_type = [0 ] # Select only type 0
493+ all_types = list (range (self .ntypes ))
494+ exclude_types = [t for t in all_types if t not in sel_type ]
495+
496+ dp_obj = DipoleFittingDP (
497+ ntypes = self .ntypes ,
498+ dim_descrpt = 20 ,
499+ embedding_width = 30 ,
500+ neuron = [5 , 5 , 5 ],
501+ exclude_types = exclude_types ,
502+ )
503+
504+ # Verify exclude_types is set correctly
505+ self .assertEqual (set (dp_obj .exclude_types ), set (exclude_types ))
506+
507+ # Verify get_sel_type returns the correct types
508+ selected_types = dp_obj .get_sel_type ()
509+ self .assertEqual (set (selected_types ), set (sel_type ))
510+
511+ def test_serialization_with_excluded_types (self ):
512+ """Test that sel_type is correctly stored in DipoleFittingSeA."""
513+ if not INSTALLED_TF :
514+ self .skipTest ("TensorFlow not available" )
515+
516+ # Test with excluding one type
517+ sel_type = [0 ] # Only select type 0, exclude type 1
518+
519+ tf_obj = DipoleFittingTF (
520+ ntypes = self .ntypes ,
521+ dim_descrpt = 20 ,
522+ embedding_width = 30 ,
523+ neuron = [5 , 5 , 5 ],
524+ sel_type = sel_type ,
525+ )
526+
527+ # Verify that sel_type is correctly stored
528+ self .assertEqual (tf_obj .sel_type , sel_type )
529+
530+ # Verify that sel_mask reflects the excluded types
531+ expected_mask = np .array ([True , False ]) # Only type 0 is selected
532+ np .testing .assert_array_equal (tf_obj .sel_mask , expected_mask )
533+
534+ def test_network_collection_none_handling (self ):
535+ """Test that NetworkCollection properly handles None networks."""
536+ from deepmd .dpmodel .utils .network import (
537+ NetworkCollection ,
538+ )
539+
540+ # Create a NetworkCollection with some None entries
541+ collection = NetworkCollection (ndim = 1 , ntypes = 2 )
542+
543+ # Test that None values can be set
544+ collection [0 ] = None
545+ collection [1 ] = None
546+
547+ # Test serialization with None values
548+ serialized = collection .serialize ()
549+ self .assertIn ("networks" , serialized )
550+ networks = serialized ["networks" ]
551+ self .assertEqual (len (networks ), 2 )
552+ self .assertTrue (all (net is None for net in networks ))
0 commit comments