Skip to content

Commit b9e5bed

Browse files
Copilotnjzjz
andcommitted
feat: add comprehensive tests for sel_type parameter in dipole fitting
- Add TestDipoleSelType class for cross-backend consistency testing with sel_type=[0,1] - Add TestDipoleSelTypeBehavior class with specific sel_type functionality tests: * test_tf_sel_type_all_types: verify TF dipole fitting with all types selected * test_tf_sel_type_partial: verify TF dipole fitting with partial type selection * test_dp_exclude_types_behavior: verify DP dipole fitting exclude_types behavior * test_serialization_with_excluded_types: verify sel_type is properly stored * test_network_collection_none_handling: verify NetworkCollection handles None networks - Fix TF fitting.py to handle None networks in deserialization (skip instead of assert) - All tests pass and verify the sel_type bug fix works correctly Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 721a2c3 commit b9e5bed

2 files changed

Lines changed: 331 additions & 1 deletion

File tree

deepmd/tf/fit/fitting.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,9 @@ def deserialize_network(cls, data: dict, suffix: str = "") -> dict:
244244
else:
245245
raise ValueError(f"Invalid ndim: {fittings.ndim}")
246246
network = fittings[net_idx]
247-
assert network is not None
247+
if network is None:
248+
# Skip types that are not selected (when sel_type is used)
249+
continue
248250
for layer_idx, layer in enumerate(network.layers):
249251
if layer_idx == len(network.layers) - 1:
250252
layer_name = "final_layer"

source/tests/consistent/fitting/test_dipole.py

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,331 @@ def atol(self) -> float:
222222
return 1e-4
223223
else:
224224
raise ValueError(f"Unknown precision: {precision}")
225+
226+
227+
@parameterized(
228+
(True, False), # resnet_dt
229+
("float64", "float32"), # precision
230+
(True, False), # mixed_types
231+
([0, 1],), # sel_type - only test with all types selected for consistency
232+
)
233+
class TestDipoleSelType(CommonTest, DipoleFittingTest, unittest.TestCase):
234+
@property
235+
def data(self) -> dict:
236+
(
237+
resnet_dt,
238+
precision,
239+
mixed_types,
240+
sel_type,
241+
) = self.param
242+
return {
243+
"neuron": [5, 5, 5],
244+
"resnet_dt": resnet_dt,
245+
"precision": precision,
246+
"seed": 20240217,
247+
"sel_type": sel_type, # For TF backend
248+
}
249+
250+
@property
251+
def skip_pt(self) -> bool:
252+
(
253+
resnet_dt,
254+
precision,
255+
mixed_types,
256+
sel_type,
257+
) = self.param
258+
return CommonTest.skip_pt
259+
260+
tf_class = DipoleFittingTF
261+
dp_class = DipoleFittingDP
262+
pt_class = DipoleFittingPT
263+
jax_class = DipoleFittingJAX
264+
array_api_strict_class = DipoleFittingArrayAPIStrict
265+
args = fitting_dipole()
266+
skip_jax = not INSTALLED_JAX
267+
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
268+
269+
def setUp(self) -> None:
270+
CommonTest.setUp(self)
271+
272+
self.ntypes = 2
273+
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
274+
self.inputs = np.ones((1, 6, 20), dtype=GLOBAL_NP_FLOAT_PRECISION)
275+
self.gr = np.ones((1, 6, 30, 3), dtype=GLOBAL_NP_FLOAT_PRECISION)
276+
self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
277+
# inconsistent if not sorted
278+
self.atype.sort()
279+
280+
@property
281+
def additional_data(self) -> dict:
282+
(
283+
resnet_dt,
284+
precision,
285+
mixed_types,
286+
sel_type,
287+
) = self.param
288+
# For DP/PT backends, use exclude_types instead of sel_type
289+
all_types = list(range(self.ntypes))
290+
exclude_types = [t for t in all_types if t not in sel_type]
291+
return {
292+
"ntypes": self.ntypes,
293+
"dim_descrpt": self.inputs.shape[-1],
294+
"mixed_types": mixed_types,
295+
"embedding_width": 30,
296+
"exclude_types": exclude_types, # For DP/PT backends
297+
}
298+
299+
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
300+
(
301+
resnet_dt,
302+
precision,
303+
mixed_types,
304+
sel_type,
305+
) = self.param
306+
return self.build_tf_fitting(
307+
obj,
308+
self.inputs.ravel(),
309+
self.gr,
310+
self.natoms,
311+
self.atype,
312+
None,
313+
suffix,
314+
)
315+
316+
def eval_pt(self, pt_obj: Any) -> Any:
317+
(
318+
resnet_dt,
319+
precision,
320+
mixed_types,
321+
sel_type,
322+
) = self.param
323+
return (
324+
pt_obj(
325+
torch.from_numpy(self.inputs).to(device=PT_DEVICE),
326+
torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE),
327+
torch.from_numpy(self.gr).to(device=PT_DEVICE),
328+
None,
329+
)["dipole"]
330+
.detach()
331+
.cpu()
332+
.numpy()
333+
)
334+
335+
def eval_dp(self, dp_obj: Any) -> Any:
336+
(
337+
resnet_dt,
338+
precision,
339+
mixed_types,
340+
sel_type,
341+
) = self.param
342+
return dp_obj(
343+
self.inputs,
344+
self.atype.reshape(1, -1),
345+
self.gr,
346+
None,
347+
)["dipole"]
348+
349+
def eval_jax(self, jax_obj: Any) -> Any:
350+
return np.asarray(
351+
jax_obj(
352+
jnp.asarray(self.inputs),
353+
jnp.asarray(self.atype.reshape(1, -1)),
354+
jnp.asarray(self.gr),
355+
None,
356+
)["dipole"]
357+
)
358+
359+
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
360+
return to_numpy_array(
361+
array_api_strict_obj(
362+
array_api_strict.asarray(self.inputs),
363+
array_api_strict.asarray(self.atype.reshape(1, -1)),
364+
array_api_strict.asarray(self.gr),
365+
None,
366+
)["dipole"]
367+
)
368+
369+
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
370+
if backend == self.RefBackend.TF:
371+
# shape is not same
372+
ret = ret[0].reshape(-1, self.natoms[0], 1)
373+
return (ret,)
374+
375+
@property
376+
def rtol(self) -> float:
377+
"""Relative tolerance for comparing the return value."""
378+
(
379+
resnet_dt,
380+
precision,
381+
mixed_types,
382+
sel_type,
383+
) = self.param
384+
if precision == "float64":
385+
return 1e-10
386+
elif precision == "float32":
387+
return 1e-4
388+
else:
389+
raise ValueError(f"Unknown precision: {precision}")
390+
391+
@property
392+
def atol(self) -> float:
393+
"""Absolute tolerance for comparing the return value."""
394+
(
395+
resnet_dt,
396+
precision,
397+
mixed_types,
398+
sel_type,
399+
) = self.param
400+
if precision == "float64":
401+
return 1e-10
402+
elif precision == "float32":
403+
return 1e-4
404+
else:
405+
raise ValueError(f"Unknown precision: {precision}")
406+
407+
def test_sel_type_behavior(self):
408+
"""Test that sel_type parameter works correctly across backends."""
409+
(
410+
resnet_dt,
411+
precision,
412+
mixed_types,
413+
sel_type,
414+
) = self.param
415+
416+
# Test TF backend if available
417+
if INSTALLED_TF:
418+
tf_obj = self.tf_class(**{**self.data, **self.additional_data})
419+
420+
# Verify that only selected types have fitting nets
421+
if hasattr(tf_obj, "sel_type"):
422+
self.assertEqual(set(tf_obj.sel_type), set(sel_type))
423+
if hasattr(tf_obj, "sel_mask"):
424+
expected_mask = np.array([i in sel_type for i in range(self.ntypes)])
425+
np.testing.assert_array_equal(tf_obj.sel_mask, expected_mask)
426+
427+
# Test DP backend
428+
all_types = list(range(self.ntypes))
429+
exclude_types = [t for t in all_types if t not in sel_type]
430+
dp_data = {**self.data}
431+
dp_data.pop("sel_type", None) # Remove sel_type for DP backend
432+
dp_obj = self.dp_class(**{**dp_data, **self.additional_data})
433+
434+
# Verify that exclude_types is set correctly
435+
if hasattr(dp_obj, "exclude_types"):
436+
self.assertEqual(set(dp_obj.exclude_types), set(exclude_types))
437+
438+
439+
class TestDipoleSelTypeBehavior(unittest.TestCase):
440+
"""Test sel_type behavior specifically, without cross-backend consistency."""
441+
442+
def setUp(self) -> None:
443+
self.ntypes = 2
444+
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
445+
446+
def test_tf_sel_type_all_types(self):
447+
"""Test that TF dipole fitting creates networks for all selected types."""
448+
if not INSTALLED_TF:
449+
self.skipTest("TensorFlow not available")
450+
451+
sel_type = [0, 1] # Select all types
452+
453+
tf_obj = DipoleFittingTF(
454+
ntypes=self.ntypes,
455+
dim_descrpt=20,
456+
embedding_width=30,
457+
neuron=[5, 5, 5],
458+
sel_type=sel_type,
459+
)
460+
461+
# Verify sel_type is set correctly
462+
self.assertEqual(set(tf_obj.sel_type), set(sel_type))
463+
464+
# Verify sel_mask is correct
465+
expected_mask = np.array([i in sel_type for i in range(self.ntypes)])
466+
np.testing.assert_array_equal(tf_obj.sel_mask, expected_mask)
467+
468+
def test_tf_sel_type_partial(self):
469+
"""Test that TF dipole fitting works with partial type selection."""
470+
if not INSTALLED_TF:
471+
self.skipTest("TensorFlow not available")
472+
473+
sel_type = [0] # Select only type 0
474+
475+
tf_obj = DipoleFittingTF(
476+
ntypes=self.ntypes,
477+
dim_descrpt=20,
478+
embedding_width=30,
479+
neuron=[5, 5, 5],
480+
sel_type=sel_type,
481+
)
482+
483+
# Verify sel_type is set correctly
484+
self.assertEqual(set(tf_obj.sel_type), set(sel_type))
485+
486+
# Verify sel_mask is correct
487+
expected_mask = np.array([i in sel_type for i in range(self.ntypes)])
488+
np.testing.assert_array_equal(tf_obj.sel_mask, expected_mask)
489+
490+
def test_dp_exclude_types_behavior(self):
491+
"""Test that DP dipole fitting excludes the correct types."""
492+
sel_type = [0] # Select only type 0
493+
all_types = list(range(self.ntypes))
494+
exclude_types = [t for t in all_types if t not in sel_type]
495+
496+
dp_obj = DipoleFittingDP(
497+
ntypes=self.ntypes,
498+
dim_descrpt=20,
499+
embedding_width=30,
500+
neuron=[5, 5, 5],
501+
exclude_types=exclude_types,
502+
)
503+
504+
# Verify exclude_types is set correctly
505+
self.assertEqual(set(dp_obj.exclude_types), set(exclude_types))
506+
507+
# Verify get_sel_type returns the correct types
508+
selected_types = dp_obj.get_sel_type()
509+
self.assertEqual(set(selected_types), set(sel_type))
510+
511+
def test_serialization_with_excluded_types(self):
512+
"""Test that sel_type is correctly stored in DipoleFittingSeA."""
513+
if not INSTALLED_TF:
514+
self.skipTest("TensorFlow not available")
515+
516+
# Test with excluding one type
517+
sel_type = [0] # Only select type 0, exclude type 1
518+
519+
tf_obj = DipoleFittingTF(
520+
ntypes=self.ntypes,
521+
dim_descrpt=20,
522+
embedding_width=30,
523+
neuron=[5, 5, 5],
524+
sel_type=sel_type,
525+
)
526+
527+
# Verify that sel_type is correctly stored
528+
self.assertEqual(tf_obj.sel_type, sel_type)
529+
530+
# Verify that sel_mask reflects the excluded types
531+
expected_mask = np.array([True, False]) # Only type 0 is selected
532+
np.testing.assert_array_equal(tf_obj.sel_mask, expected_mask)
533+
534+
def test_network_collection_none_handling(self):
535+
"""Test that NetworkCollection properly handles None networks."""
536+
from deepmd.dpmodel.utils.network import (
537+
NetworkCollection,
538+
)
539+
540+
# Create a NetworkCollection with some None entries
541+
collection = NetworkCollection(ndim=1, ntypes=2)
542+
543+
# Test that None values can be set
544+
collection[0] = None
545+
collection[1] = None
546+
547+
# Test serialization with None values
548+
serialized = collection.serialize()
549+
self.assertIn("networks", serialized)
550+
networks = serialized["networks"]
551+
self.assertEqual(len(networks), 2)
552+
self.assertTrue(all(net is None for net in networks))

0 commit comments

Comments
 (0)