Skip to content

Commit 5374bd2

Browse files
Copilotnjzjz
andcommitted
refactor: simplify test dipole sel_type parameters and remove redundant class
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent e857454 commit 5374bd2

1 file changed

Lines changed: 1 addition & 117 deletions

File tree

source/tests/consistent/fitting/test_dipole.py

Lines changed: 1 addition & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
(True, False), # resnet_dt
6262
("float64", "float32"), # precision
6363
(True, False), # mixed_types
64-
([], [0, 1]), # sel_type
64+
([], [0]), # sel_type
6565
)
6666
class TestDipole(CommonTest, DipoleFittingTest, unittest.TestCase):
6767
@property
@@ -241,119 +241,3 @@ def atol(self) -> float:
241241
return 1e-4
242242
else:
243243
raise ValueError(f"Unknown precision: {precision}")
244-
245-
246-
class TestDipoleSelTypeBehavior(unittest.TestCase):
247-
"""Test sel_type behavior specifically, without cross-backend consistency."""
248-
249-
def setUp(self) -> None:
250-
self.ntypes = 2
251-
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
252-
253-
def test_tf_sel_type_all_types(self):
254-
"""Test that TF dipole fitting creates networks for all selected types."""
255-
if not INSTALLED_TF:
256-
self.skipTest("TensorFlow not available")
257-
258-
sel_type = [0, 1] # Select all types
259-
260-
tf_obj = DipoleFittingTF(
261-
ntypes=self.ntypes,
262-
dim_descrpt=20,
263-
embedding_width=30,
264-
neuron=[5, 5, 5],
265-
sel_type=sel_type,
266-
)
267-
268-
# Verify sel_type is set correctly
269-
self.assertEqual(set(tf_obj.sel_type), set(sel_type))
270-
271-
# Verify sel_mask is correct
272-
expected_mask = np.array([i in sel_type for i in range(self.ntypes)])
273-
np.testing.assert_array_equal(tf_obj.sel_mask, expected_mask)
274-
275-
def test_tf_sel_type_partial(self):
276-
"""Test that TF dipole fitting works with partial type selection."""
277-
if not INSTALLED_TF:
278-
self.skipTest("TensorFlow not available")
279-
280-
sel_type = [0] # Select only type 0
281-
282-
tf_obj = DipoleFittingTF(
283-
ntypes=self.ntypes,
284-
dim_descrpt=20,
285-
embedding_width=30,
286-
neuron=[5, 5, 5],
287-
sel_type=sel_type,
288-
)
289-
290-
# Verify sel_type is set correctly
291-
self.assertEqual(set(tf_obj.sel_type), set(sel_type))
292-
293-
# Verify sel_mask is correct
294-
expected_mask = np.array([i in sel_type for i in range(self.ntypes)])
295-
np.testing.assert_array_equal(tf_obj.sel_mask, expected_mask)
296-
297-
def test_dp_exclude_types_behavior(self):
298-
"""Test that DP dipole fitting excludes the correct types."""
299-
sel_type = [0] # Select only type 0
300-
all_types = list(range(self.ntypes))
301-
exclude_types = [t for t in all_types if t not in sel_type]
302-
303-
dp_obj = DipoleFittingDP(
304-
ntypes=self.ntypes,
305-
dim_descrpt=20,
306-
embedding_width=30,
307-
neuron=[5, 5, 5],
308-
exclude_types=exclude_types,
309-
)
310-
311-
# Verify exclude_types is set correctly
312-
self.assertEqual(set(dp_obj.exclude_types), set(exclude_types))
313-
314-
# Verify get_sel_type returns the correct types
315-
selected_types = dp_obj.get_sel_type()
316-
self.assertEqual(set(selected_types), set(sel_type))
317-
318-
def test_serialization_with_excluded_types(self):
319-
"""Test that sel_type is correctly stored in DipoleFittingSeA."""
320-
if not INSTALLED_TF:
321-
self.skipTest("TensorFlow not available")
322-
323-
# Test with excluding one type
324-
sel_type = [0] # Only select type 0, exclude type 1
325-
326-
tf_obj = DipoleFittingTF(
327-
ntypes=self.ntypes,
328-
dim_descrpt=20,
329-
embedding_width=30,
330-
neuron=[5, 5, 5],
331-
sel_type=sel_type,
332-
)
333-
334-
# Verify that sel_type is correctly stored
335-
self.assertEqual(tf_obj.sel_type, sel_type)
336-
337-
# Verify that sel_mask reflects the excluded types
338-
expected_mask = np.array([True, False]) # Only type 0 is selected
339-
np.testing.assert_array_equal(tf_obj.sel_mask, expected_mask)
340-
341-
def test_network_collection_none_handling(self):
342-
"""Test that NetworkCollection properly handles None networks."""
343-
from deepmd.dpmodel.utils.network import (
344-
NetworkCollection,
345-
)
346-
347-
# Create a NetworkCollection with some None entries
348-
collection = NetworkCollection(ndim=1, ntypes=2)
349-
350-
# Test that None values can be set
351-
collection[0] = None
352-
collection[1] = None
353-
354-
# Test serialization with None values
355-
serialized = collection.serialize()
356-
self.assertIn("networks", serialized)
357-
networks = serialized["networks"]
358-
self.assertEqual(len(networks), 2)
359-
self.assertTrue(all(net is None for net in networks))

0 commit comments

Comments
 (0)